[
  {
    "path": ".clang-format",
    "content": "BasedOnStyle: Google\nUseTab: Never\nIndentWidth: 4\nColumnLimit: 140\nAccessModifierOffset: -4\n\n# Force pointers to the type for C++.\nDerivePointerAlignment: false\nPointerAlignment: Left\nReferenceAlignment: Left\nAllowShortFunctionsOnASingleLine: Inline\nAllowShortIfStatementsOnASingleLine: false\nAllowShortLoopsOnASingleLine: false\nAlignOperands: false\nBreakBeforeBinaryOperators: None\nCpp11BracedListStyle: true\nContinuationIndentWidth: 4\n\nBinPackArguments: false\nBinPackParameters: false\n\n"
  },
  {
    "path": ".github/workflows/format.yml",
    "content": "name: Code Format Check\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n  format-check:\n    runs-on: ubuntu-latest\n\n    steps:\n    - name: Checkout source\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 0\n\n    - name: Setup environment\n      run: |\n        sudo apt-get update\n        sudo apt-get install -y bash\n\n    - name: Run format.sh\n      run: |\n        bash ./format.sh\n\n    # If format.sh returns non-zero, GitHub Actions will mark the job as failed."
  },
  {
    "path": ".gitignore",
    "content": "compile_commands.json\n.idea\n.DS_Store\n*.pyc\nbuild/\n.cache/\n.vscode/\n*/cmake-build-*/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 DeepSeek\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# DeepEP\n\nDeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.\n\nTo align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from NVLink domain to RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they support SM (Streaming Multiprocessors) number control.\n\nFor latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource.\n\nNotice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper.\n\n## Performance\n\n### Normal kernels with NVLink and RDMA forwarding\n\nWe test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). 
And we follow the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).\n\n|   Type    | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |\n|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|\n| Intranode |      8       |  153 GB/s (NVLink)   |      8      |  158 GB/s (NVLink)   |\n| Internode |      16      |    43 GB/s (RDMA)    |     16      |    43 GB/s (RDMA)    |\n| Internode |      32      |    58 GB/s (RDMA)    |     32      |    57 GB/s (RDMA)    |\n| Internode |      64      |    51 GB/s (RDMA)    |     64      |    50 GB/s (RDMA)    |\n\n**News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!\n\n### Low-latency kernels with pure RDMA\n\nWe test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). 
And we follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining).\n\n| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |\n|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|\n|      8       |  77 us  |    98 GB/s     |      8      | 114 us  |    127 GB/s    |\n|      16      | 118 us  |    63 GB/s     |     16      | 195 us  |    74 GB/s     |\n|      32      | 155 us  |    48 GB/s     |     32      | 273 us  |    53 GB/s     |\n|      64      | 173 us  |    43 GB/s     |     64      | 314 us  |    46 GB/s     |\n|     128      | 192 us  |    39 GB/s     |     128     | 369 us  |    39 GB/s     |\n|     256      | 194 us  |    39 GB/s     |     256     | 360 us  |    40 GB/s     |\n\n**News (2025.06.05)**: low-latency kernels now leverage NVLink as much as possible, see [#173](https://github.com/deepseek-ai/DeepEP/pull/173) for more details. Thanks for the contribution!\n\n## Quick start\n\n### Requirements\n\n- Ampere (SM80), Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support\n- Python 3.8 and above\n- CUDA version\n  - CUDA 11.0 and above for SM80 GPUs\n  - CUDA 12.3 and above for SM90 GPUs\n- PyTorch 2.1 and above\n- NVLink for intranode communication\n- RDMA network for internode communication\n\n### Download and install NVSHMEM dependency\n\nDeepEP also depends on NVSHMEM. 
Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions.\n\n### Development\n\n```bash\n# Build and make symbolic links for SO files\nNVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build\n# You may modify the specific SO names according to your own platform\nln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so\n\n# Run test cases\n# NOTES: you may modify the `init_dist` function in `tests/utils.py`\n# according to your own cluster settings, and launch on multiple nodes\npython tests/test_intranode.py\npython tests/test_internode.py\npython tests/test_low_latency.py\n```\n\n### Installation\n\n```bash\nNVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install\n```\n\n#### Installation environment variables\n\n- `NVSHMEM_DIR`: the path to the NVSHMEM directory; if not specified, all internode and low-latency features are disabled\n- `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features; this is required for non-SM90 devices or CUDA 11\n- `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST=\"9.0\"`\n- `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions, see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details\n\nThen, import `deep_ep` in your Python project, and enjoy!\n\n## Network configurations\n\nDeepEP is fully tested with InfiniBand networks. 
However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.\n\n### Traffic isolation\n\nTraffic isolation is supported by InfiniBand through Virtual Lanes (VL).\n\nTo prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:\n\n- workloads using normal kernels\n- workloads using low-latency kernels\n- other workloads\n\nFor DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.\n\n### Adaptive routing\n\nAdaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:\n\n- enable adaptive routing in environments with heavy network loads\n- use static routing in environments with light network loads\n\n### Congestion control\n\nCongestion control is disabled as we have not observed significant congestion in our production environment.\n\n## Interfaces and examples\n\n### Example use in model training or inference prefilling\n\nThe normal kernels can be used in model training or the inference prefilling phase (without the backward part) as the below example code shows.\n\n```python\nimport torch\nimport torch.distributed as dist\nfrom typing import List, Tuple, Optional, Union\n\nfrom deep_ep import Buffer, EventOverlap\n\n# Communication buffer (will allocate at runtime)\n_buffer: Optional[Buffer] = None\n\n# Set the number of SMs to use\n# NOTES: this is a static variable\nBuffer.set_num_sms(24)\n\n\n# You may call this function at the framework initialization\ndef get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:\n    global _buffer\n\n    # NOTES: you may also replace `get_*_config` with your auto-tuned results via 
all the tests\n    num_nvl_bytes, num_rdma_bytes = 0, 0\n    for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):\n        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)\n        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)\n\n    # Allocate a new buffer if none exists, the group changed, or the current one is too small\n    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:\n        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)\n    return _buffer\n\n\ndef get_hidden_bytes(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> int:\n    t = x[0] if isinstance(x, tuple) else x\n    return t.size(1) * max(t.element_size(), 2)\n\n\ndef dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n                     topk_idx: torch.Tensor, topk_weights: torch.Tensor,\n                     num_experts: int, previous_event: Optional[EventOverlap] = None) -> \\\n        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:\n    # NOTES: the optional `previous_event` is a captured CUDA event that you want to make a dependency\n    # of the dispatch kernel, which may be useful for communication-computation overlap. 
For more information, please\n    # refer to the docs of `Buffer.dispatch`\n    global _buffer\n\n    # Calculate layout before actual dispatch\n    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \\\n        _buffer.get_dispatch_layout(topk_idx, num_experts,\n                                    previous_event=previous_event, async_finish=True,\n                                    allocate_on_comm_stream=previous_event is not None)\n    # Do MoE dispatch\n    # NOTES: the CPU waits for the GPU's signal to arrive, so this is not compatible with CUDA graphs\n    # unless you specify `num_worst_tokens`, which is supported for intranode only\n    # For more advanced usages, please refer to the docs of the `dispatch` function\n    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \\\n        _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,\n                         num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,\n                         is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,\n                         previous_event=previous_event, async_finish=True,\n                         allocate_on_comm_stream=True)\n    # For event management, please refer to the docs of the `EventOverlap` class\n    return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event\n\n\ndef dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \\\n        Tuple[torch.Tensor, torch.Tensor, EventOverlap]:\n    global _buffer\n\n    # The backward process of MoE dispatch is actually a combine\n    # For more advanced usages, please refer to the docs of the `combine` function\n    combined_grad_x, combined_grad_recv_topk_weights, event = \\\n        _buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, 
async_finish=True)\n\n    # For event management, please refer to the docs of the `EventOverlap` class\n    return combined_grad_x, combined_grad_recv_topk_weights, event\n\n\ndef combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \\\n        Tuple[torch.Tensor, EventOverlap]:\n    global _buffer\n\n    # Do MoE combine\n    # For more advanced usages, please refer to the docs of the `combine` function\n    combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,\n                                           allocate_on_comm_stream=previous_event is not None)\n\n    # For event management, please refer to the docs of the `EventOverlap` class\n    return combined_x, event\n\n\ndef combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n                     handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \\\n        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:\n    global _buffer\n\n    # The backward process of MoE combine is actually a dispatch\n    # For more advanced usages, please refer to the docs of the `dispatch` function\n    grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,\n                                                 previous_event=previous_event,\n                                                 allocate_on_comm_stream=previous_event is not None)\n\n    # For event management, please refer to the docs of the `EventOverlap` class\n    return grad_x, event\n```\n\nMoreover, inside the dispatch function, we may not know how many tokens to receive for the current rank. 
So the CPU implicitly waits for the GPU's received-count signal, as the following figure shows.\n\n![normal](figures/normal.png)\n\n### Example use in inference decoding\n\nThe low-latency kernels can be used in the inference decoding phase as the below example code shows.\n\n```python\nimport torch\nimport torch.distributed as dist\nfrom typing import Tuple, Optional\n\nfrom deep_ep import Buffer\n\n# Communication buffer (will allocate at runtime)\n# NOTES: there is no SM control API for the low-latency kernels\n_buffer: Optional[Buffer] = None\n\n\n# You may call this function at the framework initialization\ndef get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:\n    # NOTES: the low-latency mode will consume much more space than the normal mode\n    # so we recommend keeping `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) below 256\n    global _buffer\n    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)\n\n    # Allocate a new buffer if none exists, the group or mode changed, or the current one is too small\n    if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:\n        # NOTES: for the best performance, the QP number **must** be equal to the number of the local experts\n        assert num_experts % group.size() == 0\n        _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())\n    return _buffer\n\n\ndef low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):\n    global _buffer\n\n    # Do MoE dispatch, compatible with CUDA graph (but you may need to restore some buffer state when replaying)\n    recv_hidden_states, recv_expert_count, handle, event, hook = \\\n        
_buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,\n                                     async_finish=False, return_recv_hook=True)\n\n    # NOTES: the actual tensor will not be received until you call `hook()`,\n    # which enables double-batch overlapping **without any SM occupation**\n    # If you don't want to overlap, please set `return_recv_hook=False`\n    # Later, you can use our GEMM library to do the computation with this specific format\n    return recv_hidden_states, recv_expert_count, handle, event, hook\n\n\ndef low_latency_combine(hidden_states: torch.Tensor,\n                        topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):\n    global _buffer\n\n    # Do MoE combine, compatible with CUDA graph (but you may need to restore some buffer state when replaying)\n    combined_hidden_states, event_overlap, hook = \\\n        _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,\n                                    async_finish=False, return_recv_hook=True)\n\n    # NOTES: the same behavior as described in the dispatch kernel\n    return combined_hidden_states, event_overlap, hook\n```\n\nFor two-micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffic happens in the background, without costing any GPU SMs from the computation part. Note, however, that the overlapped parts can be adjusted: the four parts of attention/dispatch/MoE/combine may not have exactly the same execution time. 
You may adjust the stage settings according to your workload.\n\n![low-latency](figures/low-latency.png)\n\n## Roadmap\n\n- [x] AR support\n- [x] Refactor low-latency mode AR code\n- [x] A100 support (intranode only)\n- [x] Support BF16 for the low-latency dispatch kernel\n- [x] Support NVLink protocol for intranode low-latency kernels\n- [ ] TMA copy instead of LD/ST\n  - [x] Intranode kernels\n  - [ ] Internode kernels\n  - [ ] Low-latency kernels\n- [ ] SM-free kernels and refactors\n- [ ] Fully remove undefined-behavior PTX instructions\n\n## Notices\n\n#### Easier potential overall design\n\nThe current DeepEP implementation uses queues for communication buffers, which save memory but introduce complexity and potential deadlocks. If you're implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39.\n\n#### Undefined-behavior PTX usage\n\n- For extreme performance, we discovered and use an undefined-behavior PTX usage: using the read-only PTX `ld.global.nc.L1::no_allocate.L2::256B` to **read volatile data**. The PTX modifier `.nc` indicates that a non-coherent cache is used. However, our tests show that correctness holds with `.L1::no_allocate` on Hopper architectures, and performance is much better. We suspect the reason is that the non-coherent cache is unified with L1, and the L1 modifier is not just a hint but a strong option, so correctness is guaranteed because no dirty data remains in L1.\n- Initially, because NVCC could not automatically unroll volatile read PTX, we tried using `__ldg` (i.e., `ld.nc`). Even compared to manually unrolled volatile reads, it was significantly faster (likely due to additional compiler optimizations). However, the results could be incorrect or dirty. 
After consulting the PTX documentation, we discovered that L1 and non-coherent cache are unified on Hopper architectures. We speculated that `.L1::no_allocate` might resolve the issue, leading to this discovery.\n- If the kernels do not work on other platforms, you may set `DISABLE_AGGRESSIVE_PTX_INSTRS=1` when running `setup.py` to disable this, or file an issue.\n\n#### Auto-tuning on your cluster\n\nFor better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized on DeepSeek's internal cluster.\n\n## License\n\nThis code repository is released under [the MIT License](LICENSE), except for code that references NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which is subject to [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).\n\n## Experimental Branches\n\n- [Zero-copy](https://github.com/deepseek-ai/DeepEP/pull/453)\n  - Removes the copy between PyTorch tensors and communication buffers, which significantly reduces SM usage for normal kernels\n  - This PR is authored by **Tencent Network Platform Department**\n- [Eager](https://github.com/deepseek-ai/DeepEP/pull/437)\n  - Uses a low-latency protocol to remove the extra RTT latency introduced by RDMA atomic OPs\n- [Hybrid-EP](https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep)\n  - A new backend implementation using TMA instructions for minimal SM usage and larger NVLink domain support\n  - Fine-grained communication-computation overlap for single-batch scenarios\n  - PCIe kernel support for non-NVLink environments\n  - NVFP4 data type support\n- [AntGroup-Opt](https://github.com/deepseek-ai/DeepEP/tree/antgroup-opt)\n  - This optimization series is authored by **AntGroup Network Platform Department**\n  - [Normal-SMFree](https://github.com/deepseek-ai/DeepEP/pull/347) Eliminating SM from RDMA path by decoupling comm-kernel execution from NIC token transfer, 
freeing SMs for compute\n  - [LL-SBO](https://github.com/deepseek-ai/DeepEP/pull/483) Overlapping Down GEMM computation with Combine Send communication via signaling mechanism to reduce end-to-end latency\n  - [LL-Layered](https://github.com/deepseek-ai/DeepEP/pull/500) Optimizing cross-node LL operator communication using rail-optimized forwarding and data merging to reduce latency\n- [Mori-EP](https://github.com/deepseek-ai/DeepEP/tree/mori-ep)\n  - ROCm/AMD GPU support powered by [MORI](https://github.com/ROCm/mori) backend (low-latency mode)\n\n## Community Forks\n\n- [uccl/uccl-ep](https://github.com/uccl-project/uccl/tree/main/ep) - Enables running DeepEP on heterogeneous GPUs (e.g., Nvidia, AMD) and NICs (e.g., EFA, Broadcom, CX7)\n- [Infrawaves/DeepEP_ibrc_dual-ports_multiQP](https://github.com/Infrawaves/DeepEP_ibrc_dual-ports_multiQP) - Adds multi-QP solution and dual-port NIC support in IBRC transport\n- [antgroup/DeepXTrace](https://github.com/antgroup/DeepXTrace) - A diagnostic analyzer for efficient and precise localization of slow ranks\n- [ROCm/mori](https://github.com/ROCm/mori) - AMD's next-generation communication library for performance-critical AI workloads (e.g., Wide EP, KVCache transfer, Collectives)\n\n## Citation\n\nIf you use this codebase or otherwise find our work valuable, please cite:\n\n```bibtex\n@misc{deepep2025,\n      title={DeepEP: an efficient expert-parallel communication library},\n      author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao},\n      year={2025},\n      publisher = {GitHub},\n      howpublished = {\\url{https://github.com/deepseek-ai/DeepEP}},\n}\n```\n"
  },
  {
    "path": "csrc/CMakeLists.txt",
    "content": "# NOTES: this CMake is only for debugging; for setup, please use Torch extension\ncmake_minimum_required(VERSION 3.10)\nproject(deep_ep LANGUAGES CUDA CXX)\nset(CMAKE_VERBOSE_MAKEFILE ON)\n\nset(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -O3 -fPIC\")\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -O3 -fPIC\")\nset(CUDA_SEPARABLE_COMPILATION ON)\nlist(APPEND CUDA_NVCC_FLAGS \"-DENABLE_FAST_DEBUG\")\nlist(APPEND CUDA_NVCC_FLAGS \"-O3\")\nlist(APPEND CUDA_NVCC_FLAGS \"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage\")\n\nset(USE_SYSTEM_NVTX on)\nset(CUDA_ARCH_LIST \"9.0\" CACHE STRING \"List of CUDA architectures to compile\")\nset(TORCH_CUDA_ARCH_LIST \"${CUDA_ARCH_LIST}\")\n\nfind_package(CUDAToolkit REQUIRED)\nfind_package(pybind11 REQUIRED)\nfind_package(Torch REQUIRED)\nfind_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem)\n\nadd_library(nvshmem ALIAS nvshmem::nvshmem)\nadd_library(nvshmem_host ALIAS nvshmem::nvshmem_host)\nadd_library(nvshmem_device ALIAS nvshmem::nvshmem_device)\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CUDA_STANDARD 17)\n\ninclude_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})\nlink_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})\n\nadd_subdirectory(kernels)\n\n# Link CPP and CUDA together\npybind11_add_module(deep_ep_cpp deep_ep.cpp)\ntarget_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python)\n"
  },
  {
    "path": "csrc/config.hpp",
    "content": "#pragma once\n\n#include \"kernels/api.cuh\"\n#include \"kernels/exception.cuh\"\n\nnamespace deep_ep {\n\ntemplate <typename dtype_t>\ndtype_t ceil_div(dtype_t a, dtype_t b) {\n    return (a + b - 1) / b;\n}\n\ntemplate <typename dtype_t>\ndtype_t align_up(dtype_t a, dtype_t b) {\n    return ceil_div<dtype_t>(a, b) * b;\n}\n\ntemplate <typename dtype_t>\ndtype_t align_down(dtype_t a, dtype_t b) {\n    return a / b * b;\n}\n\nstruct Config {\n    int num_sms;\n    int num_max_nvl_chunked_send_tokens;\n    int num_max_nvl_chunked_recv_tokens;\n    int num_max_rdma_chunked_send_tokens;\n    int num_max_rdma_chunked_recv_tokens;\n\n    Config(int num_sms,\n           int num_max_nvl_chunked_send_tokens,\n           int num_max_nvl_chunked_recv_tokens,\n           int num_max_rdma_chunked_send_tokens,\n           int num_max_rdma_chunked_recv_tokens)\n        : num_sms(num_sms),\n          num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),\n          num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),\n          num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),\n          num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {\n        EP_HOST_ASSERT(num_sms >= 0);\n        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);\n        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);\n        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);\n\n        // Ceil up RDMA buffer size\n        this->num_max_rdma_chunked_recv_tokens = align_up<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);\n        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);\n        // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push\n        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= 
num_max_rdma_chunked_recv_tokens / 2);\n    }\n\n    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {\n        // Below are some assumptions\n        // TODO: add assertions\n        constexpr int kNumMaxTopK = 128;\n        constexpr int kNumMaxScales = 128;\n        EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);\n        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);\n        const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);\n        const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);\n        const int num_channels = num_sms / 2;\n\n        size_t num_bytes = 0;\n        num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);\n        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;\n#ifndef DISABLE_NVSHMEM\n        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();\n#endif\n        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t);\n        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);\n        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);\n        num_bytes = ((num_bytes + 127) / 128) * 128;\n        return num_bytes;\n    }\n\n    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {\n#ifndef DISABLE_NVSHMEM\n        // Legacy mode\n        if (num_ranks <= NUM_MAX_NVL_PEERS)\n            return 0;\n\n        // Below are some assumptions\n        // TODO: add assertions\n        constexpr int kNumMaxTopK = 128;\n        constexpr int kNumMaxScales = 128;\n        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);\n        EP_HOST_ASSERT(num_sms % 2 == 0);\n        const int num_rdma_ranks = num_ranks 
/ NUM_MAX_NVL_PEERS;\n        const int num_channels = num_sms / 2;\n\n        size_t num_bytes = 0;\n        num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);\n        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;\n        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;\n        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2;\n        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;\n        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;\n        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;\n        num_bytes = ((num_bytes + 127) / 128) * 128;\n        return num_bytes;\n#else\n        EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n#endif\n    }\n};\n\nstruct LowLatencyBuffer {\n    int num_clean_int = 0;\n\n    void* dispatch_rdma_send_buffer = nullptr;\n    void* dispatch_rdma_recv_data_buffer = nullptr;\n    int* dispatch_rdma_recv_count_buffer = nullptr;\n\n    void* combine_rdma_send_buffer = nullptr;\n    void* combine_rdma_recv_data_buffer = nullptr;\n    int* combine_rdma_recv_flag_buffer = nullptr;\n\n    void* combine_rdma_send_buffer_data_start = nullptr;\n    size_t num_bytes_per_combine_msg = 0;\n\n    std::pair<int*, int> clean_meta() {\n        EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);\n        return {dispatch_rdma_recv_count_buffer, num_clean_int};\n    }\n};\n\nstruct LowLatencyLayout {\n    size_t total_bytes = 0;\n    LowLatencyBuffer buffers[2];\n\n    template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>\n    out_ptr_t 
advance(const in_ptr_t& ptr, size_t count) {\n        return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);\n    }\n\n    LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {\n        const int num_scales = hidden / 128;\n\n        // Dispatch and combine layout:\n        //  - 2 symmetric odd/even send buffer\n        //  - 2 symmetric odd/even receive buffers\n        //  - 2 symmetric odd/even signaling buffers\n\n        // Message sizes\n        // NOTES: you should add a control `int4` for combine messages if you want to do data transformation\n        // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max\n        EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);\n        size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));\n        size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);\n\n        // Send buffer\n        size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;\n        size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;\n        size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);\n        EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);\n        total_bytes += send_buffer_bytes * 2;\n\n        // Symmetric receive buffers\n        // TODO: optimize memory usages\n        size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;\n        size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;\n        size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);\n        EP_HOST_ASSERT(recv_buffer_bytes % 
sizeof(int4) == 0);\n        total_bytes += recv_buffer_bytes * 2;\n\n        // Symmetric signaling buffers\n        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);\n        size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;\n        size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);\n        size_t signaling_buffer_bytes_aligned = align_up<size_t>(signaling_buffer_bytes, 128);\n        total_bytes += signaling_buffer_bytes_aligned * 2;\n\n        // Assign pointers\n        // NOTES: we still leave some space for distinguishing dispatch/combine buffer,\n        // so you may see some parameters are duplicated\n        for (int i = 0; i < 2; ++i) {\n            buffers[i] = {static_cast<int>(signaling_buffer_bytes / sizeof(int)),\n                          advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),\n                          advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),\n                          advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i),\n                          advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),\n                          advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),\n                          advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i),\n                          advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),\n                          num_bytes_per_combine_msg};\n        }\n    }\n};\n\nsize_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {\n    auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;\n    return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) 
/ NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;\n}\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/deep_ep.cpp",
    "content": "#include \"deep_ep.hpp\"\n\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/CUDADataType.h>\n#include <cuda_runtime.h>\n#include <pybind11/functional.h>\n#include <torch/python.h>\n\n#include <chrono>\n#include <memory>\n\n#include \"kernels/api.cuh\"\n#include \"kernels/configs.cuh\"\n\nnamespace shared_memory {\nvoid cu_mem_set_access_all(void* ptr, size_t size) {\n    int device_count;\n    CUDA_CHECK(cudaGetDeviceCount(&device_count));\n\n    CUmemAccessDesc access_desc[device_count];\n    for (int idx = 0; idx < device_count; ++idx) {\n        access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;\n        access_desc[idx].location.id = idx;\n        access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;\n    }\n\n    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));\n}\n\nvoid cu_mem_free(void* ptr) {\n    CUmemGenericAllocationHandle handle;\n    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));\n\n    size_t size = 0;\n    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));\n\n    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));\n    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));\n    CU_CHECK(cuMemRelease(handle));\n}\n\nsize_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {\n    size_t size = (size_raw + granularity - 1) & ~(granularity - 1);\n    if (size == 0)\n        size = granularity;\n    return size;\n}\n\nSharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric) : use_fabric(use_fabric) {}\n\nvoid SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {\n    if (use_fabric) {\n        CUdevice device;\n        CU_CHECK(cuCtxGetDevice(&device));\n\n        CUmemAllocationProp prop = {};\n        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;\n        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;\n        prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;\n        prop.location.id = device;\n\n        size_t granularity = 0;\n        
CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));\n\n        size_t size = get_size_align_to_granularity(size_raw, granularity);\n\n        CUmemGenericAllocationHandle handle;\n        CU_CHECK(cuMemCreate(&handle, size, &prop, 0));\n\n        CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0));\n        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));\n        cu_mem_set_access_all(*ptr, size);\n    } else {\n        CUDA_CHECK(cudaMalloc(ptr, size_raw));\n    }\n}\n\nvoid SharedMemoryAllocator::free(void* ptr) {\n    if (use_fabric) {\n        cu_mem_free(ptr);\n    } else {\n        CUDA_CHECK(cudaFree(ptr));\n    }\n}\n\nvoid SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {\n    size_t size = 0;\n    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));\n\n    mem_handle->size = size;\n\n    if (use_fabric) {\n        CUmemGenericAllocationHandle handle;\n        CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));\n\n        CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));\n    } else {\n        CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));\n    }\n}\n\nvoid SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {\n    if (use_fabric) {\n        size_t size = mem_handle->size;\n\n        CUmemGenericAllocationHandle handle;\n        CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));\n\n        CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));\n        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));\n        cu_mem_set_access_all(*ptr, size);\n    } else {\n        CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));\n    }\n}\n\nvoid SharedMemoryAllocator::close_mem_handle(void* 
ptr) {\n    if (use_fabric) {\n        cu_mem_free(ptr);\n    } else {\n        CUDA_CHECK(cudaIpcCloseMemHandle(ptr));\n    }\n}\n}  // namespace shared_memory\n\nnamespace deep_ep {\n\nBuffer::Buffer(int rank,\n               int num_ranks,\n               int64_t num_nvl_bytes,\n               int64_t num_rdma_bytes,\n               bool low_latency_mode,\n               bool explicitly_destroy,\n               bool enable_shrink,\n               bool use_fabric)\n    : rank(rank),\n      num_ranks(num_ranks),\n      num_nvl_bytes(num_nvl_bytes),\n      num_rdma_bytes(num_rdma_bytes),\n      enable_shrink(enable_shrink),\n      low_latency_mode(low_latency_mode),\n      explicitly_destroy(explicitly_destroy),\n      comm_stream(at::cuda::getStreamFromPool(true)),\n      shared_memory_allocator(use_fabric) {\n    // Metadata memory\n    int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);\n    int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);\n    int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);\n\n    // Common checks\n    EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, \"Invalid alignment\");\n    EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and\n                   (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));\n    EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and\n                   (low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));\n    EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < std::numeric_limits<int>::max());\n    EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits<int>::max());\n    EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));\n    EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);\n    if (num_rdma_bytes > 0)\n        EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or 
low_latency_mode);\n\n    // Get ranks\n    CUDA_CHECK(cudaGetDevice(&device_id));\n    rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;\n    num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);\n#ifdef DISABLE_NVSHMEM\n    EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and \"NVSHMEM is disabled during compilation\");\n#endif\n\n    // Get device info\n    cudaDeviceProp device_prop = {};\n    CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));\n    num_device_sms = device_prop.multiProcessorCount;\n\n    // Per-channel byte counts must fit in an `int`\n    EP_HOST_ASSERT(ceil_div<int64_t>(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits<int>::max());\n    EP_HOST_ASSERT(ceil_div<int64_t>(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits<int>::max());\n\n    if (num_nvl_bytes > 0) {\n        // Local IPC: alloc local memory and set local IPC handles\n        shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank],\n                                       num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);\n        shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);\n        buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);\n\n        // Set barrier signals\n        barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);\n        barrier_signal_ptrs_gpu =\n            reinterpret_cast<int**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes);\n\n        // No need to synchronize here: `sync` performs a full device sync\n        CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));\n    }\n\n    // Create 32 MiB workspace\n    
CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES));\n    CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));\n\n    // MoE counter\n    CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped));\n    CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast<int*>(moe_recv_counter), 0));\n    *moe_recv_counter = -1;\n\n    // MoE expert-level counter\n    CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped));\n    CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast<int*>(moe_recv_expert_counter), 0));\n    for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i)\n        moe_recv_expert_counter[i] = -1;\n\n    // MoE RDMA-level counter\n    if (num_rdma_ranks > 0) {\n        CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped));\n        CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));\n        *moe_recv_rdma_counter = -1;\n    }\n}\n\nBuffer::~Buffer() noexcept(false) {\n    if (not explicitly_destroy) {\n        destroy();\n    } else if (not destroyed) {\n        printf(\"WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\\n\");\n        fflush(stdout);\n    }\n}\n\nbool Buffer::is_available() const {\n    return available;\n}\n\nbool Buffer::is_internode_available() const {\n    return is_available() and num_ranks > NUM_MAX_NVL_PEERS;\n}\n\nint Buffer::get_num_rdma_ranks() const {\n    return num_rdma_ranks;\n}\n\nint Buffer::get_rdma_rank() const {\n    return rdma_rank;\n}\n\nint Buffer::get_root_rdma_rank(bool global) const {\n    return global ? 
nvl_rank : 0;\n}\n\nint Buffer::get_local_device_id() const {\n    return device_id;\n}\n\npybind11::bytearray Buffer::get_local_ipc_handle() const {\n    const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];\n    return {reinterpret_cast<const char*>(&handle), sizeof(handle)};\n}\n\npybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {\n#ifndef DISABLE_NVSHMEM\n    EP_HOST_ASSERT(rdma_rank == 0 and \"Only RDMA rank 0 can get NVSHMEM unique ID\");\n    auto unique_id = internode::get_unique_id();\n    return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};\n#else\n    EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n#endif\n}\n\ntorch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {\n    torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);\n    auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));\n    auto base_ptr = static_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;\n    auto num_bytes = use_rdma_buffer ? 
num_rdma_bytes : num_nvl_bytes;\n    return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));\n}\n\ntorch::Stream Buffer::get_comm_stream() const {\n    return comm_stream;\n}\n\nvoid Buffer::destroy() {\n    EP_HOST_ASSERT(not destroyed);\n\n    // Synchronize\n    CUDA_CHECK(cudaDeviceSynchronize());\n\n    if (num_nvl_bytes > 0) {\n        // Barrier\n        intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);\n        CUDA_CHECK(cudaDeviceSynchronize());\n\n        // Close remote IPC\n        if (is_available()) {\n            for (int i = 0; i < num_nvl_ranks; ++i)\n                if (i != nvl_rank)\n                    shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);\n        }\n\n        // Free the local buffer (it also holds the barrier signals and pointer tables)\n        shared_memory_allocator.free(buffer_ptrs[nvl_rank]);\n    }\n\n    // Free NVSHMEM\n#ifndef DISABLE_NVSHMEM\n    if (is_available() and num_rdma_bytes > 0) {\n        CUDA_CHECK(cudaDeviceSynchronize());\n        internode::barrier();\n        internode::free(rdma_buffer_ptr);\n        if (enable_shrink) {\n            internode::free(mask_buffer_ptr);\n            internode::free(sync_buffer_ptr);\n        }\n        internode::finalize();\n    }\n#endif\n\n    // Free workspace and MoE counter\n    CUDA_CHECK(cudaFree(workspace));\n    CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_counter)));\n\n    // Free MoE expert-level counter\n    CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));\n\n    destroyed = true;\n    available = false;\n}\n\nvoid Buffer::sync(const std::vector<int>& device_ids,\n                  const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles,\n                  const std::optional<pybind11::bytearray>& root_unique_id_opt) {\n    EP_HOST_ASSERT(not is_available());\n\n    // Sync IPC handles\n    if (num_nvl_bytes > 0) {\n        EP_HOST_ASSERT(num_ranks == 
device_ids.size());\n        EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());\n        for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) {\n            EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());\n            auto handle_str = std::string(all_gathered_handles[offset + i].value());\n            EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);\n            if (offset + i != rank) {\n                std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);\n                shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);\n                barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);\n            } else {\n                EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);\n            }\n        }\n\n        // Copy all buffer and barrier signal pointers to GPU\n        CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));\n        CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));\n        CUDA_CHECK(cudaDeviceSynchronize());\n    }\n\n    // Sync NVSHMEM handles and allocate memory\n#ifndef DISABLE_NVSHMEM\n    if (num_rdma_bytes > 0) {\n        // Initialize NVSHMEM\n        EP_HOST_ASSERT(root_unique_id_opt.has_value());\n        std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());\n        auto root_unique_id_str = root_unique_id_opt->cast<std::string>();\n        std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());\n        auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;\n        auto num_nvshmem_ranks = low_latency_mode ? 
num_ranks : num_rdma_ranks;\n        EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));\n        internode::barrier();\n\n        // Allocate\n        rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);\n\n        // Clean buffer (mainly for low-latency mode)\n        CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));\n\n        // Allocate and clean shrink buffers\n        if (enable_shrink) {\n            int num_mask_buffer_bytes = num_ranks * sizeof(int);\n            int num_sync_buffer_bytes = num_ranks * sizeof(int);\n            mask_buffer_ptr = reinterpret_cast<int*>(internode::alloc(num_mask_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES));\n            sync_buffer_ptr = reinterpret_cast<int*>(internode::alloc(num_sync_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES));\n            CUDA_CHECK(cudaMemset(mask_buffer_ptr, 0, num_mask_buffer_bytes));\n            CUDA_CHECK(cudaMemset(sync_buffer_ptr, 0, num_sync_buffer_bytes));\n        }\n\n        // Barrier\n        internode::barrier();\n        CUDA_CHECK(cudaDeviceSynchronize());\n    }\n#endif\n\n    // Ready to use\n    available = true;\n}\n\nstd::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>\nBuffer::get_dispatch_layout(\n    const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {\n    EP_HOST_ASSERT(topk_idx.dim() == 2);\n    EP_HOST_ASSERT(topk_idx.is_contiguous());\n    EP_HOST_ASSERT(num_experts > 0);\n\n    // Allocate all tensors on comm stream if set\n    // NOTES: do not allocate tensors upfront!\n    auto compute_stream = at::cuda::getCurrentCUDAStream();\n    if (allocate_on_comm_stream) {\n        EP_HOST_ASSERT(previous_event.has_value() and async);\n        at::cuda::setCurrentCUDAStream(comm_stream);\n    }\n\n    // Wait for previous tasks to finish\n    
if (previous_event.has_value()) {\n        stream_wait(comm_stream, previous_event.value());\n    } else {\n        stream_wait(comm_stream, compute_stream);\n    }\n\n    auto num_tokens = static_cast<int>(topk_idx.size(0)), num_topk = static_cast<int>(topk_idx.size(1));\n    auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));\n    auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();\n    auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA));\n    auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA));\n    if (is_internode_available())\n        num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));\n\n    layout::get_dispatch_layout(topk_idx.data_ptr<topk_idx_t>(),\n                                num_tokens_per_rank.data_ptr<int>(),\n                                num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,\n                                num_tokens_per_expert.data_ptr<int>(),\n                                is_token_in_rank.data_ptr<bool>(),\n                                num_tokens,\n                                num_topk,\n                                num_ranks,\n                                num_experts,\n                                comm_stream);\n\n    // Wait streams\n    std::optional<EventHandle> event;\n    if (async) {\n        event = EventHandle(comm_stream);\n        for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {\n            t.record_stream(comm_stream);\n            if (allocate_on_comm_stream)\n                t.record_stream(compute_stream);\n        }\n        for (auto& to : {num_tokens_per_rdma_rank}) {\n            to.has_value() ? 
to->record_stream(comm_stream) : void();\n            if (allocate_on_comm_stream)\n                to.has_value() ? to->record_stream(compute_stream) : void();\n        }\n    } else {\n        stream_wait(compute_stream, comm_stream);\n    }\n\n    // Switch back to the compute stream\n    if (allocate_on_comm_stream)\n        at::cuda::setCurrentCUDAStream(compute_stream);\n\n    return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event};\n}\n\nstd::tuple<torch::Tensor,\n           std::optional<torch::Tensor>,\n           std::optional<torch::Tensor>,\n           std::optional<torch::Tensor>,\n           std::vector<int>,\n           torch::Tensor,\n           torch::Tensor,\n           torch::Tensor,\n           torch::Tensor,\n           torch::Tensor,\n           std::optional<EventHandle>>\nBuffer::intranode_dispatch(const torch::Tensor& x,\n                           const std::optional<torch::Tensor>& x_scales,\n                           const std::optional<torch::Tensor>& topk_idx,\n                           const std::optional<torch::Tensor>& topk_weights,\n                           const std::optional<torch::Tensor>& num_tokens_per_rank,\n                           const torch::Tensor& is_token_in_rank,\n                           const std::optional<torch::Tensor>& num_tokens_per_expert,\n                           int cached_num_recv_tokens,\n                           const std::optional<torch::Tensor>& cached_rank_prefix_matrix,\n                           const std::optional<torch::Tensor>& cached_channel_prefix_matrix,\n                           int expert_alignment,\n                           int num_worst_tokens,\n                           const Config& config,\n                           std::optional<EventHandle>& previous_event,\n                           bool async,\n                           bool allocate_on_comm_stream) {\n    bool cached_mode = cached_rank_prefix_matrix.has_value();\n\n    //
 Each channel uses two blocks: even-numbered blocks send and odd-numbered blocks receive.\n    EP_HOST_ASSERT(config.num_sms % 2 == 0);\n    int num_channels = config.num_sms / 2;\n    if (cached_mode) {\n        EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value());\n        EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value());\n    } else {\n        EP_HOST_ASSERT(num_tokens_per_rank.has_value());\n        EP_HOST_ASSERT(num_tokens_per_expert.has_value());\n    }\n\n    // Type checks\n    EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool);\n    if (cached_mode) {\n        EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32);\n        EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32);\n    } else {\n        EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);\n        EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);\n    }\n\n    // Shape and contiguity checks\n    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());\n    EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);\n    EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());\n    EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks);\n    if (cached_mode) {\n        EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous());\n        EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks);\n        EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous());\n        EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels);\n    } else {\n        EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());\n        EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks 
== 0);\n        EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);\n        EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());\n        EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);\n    }\n\n    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));\n    auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;\n\n    // Top-k checks\n    int num_topk = 0;\n    topk_idx_t* topk_idx_ptr = nullptr;\n    float* topk_weights_ptr = nullptr;\n    EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());\n    if (topk_idx.has_value()) {\n        num_topk = static_cast<int>(topk_idx->size(1));\n        EP_HOST_ASSERT(num_experts > 0);\n        EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());\n        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());\n        EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));\n        EP_HOST_ASSERT(num_topk == topk_weights->size(1));\n        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);\n        topk_idx_ptr = topk_idx->data_ptr<topk_idx_t>();\n        topk_weights_ptr = topk_weights->data_ptr<float>();\n    }\n\n    // FP8 scales checks\n    float* x_scales_ptr = nullptr;\n    int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;\n    if (x_scales.has_value()) {\n        EP_HOST_ASSERT(x.element_size() == 1);\n        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);\n        EP_HOST_ASSERT(x_scales->dim() == 2);\n        EP_HOST_ASSERT(x_scales->size(0) == num_tokens);\n        num_scales = x_scales->dim() == 1 ? 
1 : static_cast<int>(x_scales->size(1));\n        x_scales_ptr = static_cast<float*>(x_scales->data_ptr());\n        scale_token_stride = static_cast<int>(x_scales->stride(0));\n        scale_hidden_stride = static_cast<int>(x_scales->stride(1));\n    }\n\n    // Allocate all tensors on comm stream if set\n    // NOTES: do not allocate tensors upfront!\n    auto compute_stream = at::cuda::getCurrentCUDAStream();\n    if (allocate_on_comm_stream) {\n        EP_HOST_ASSERT(previous_event.has_value() and async);\n        at::cuda::setCurrentCUDAStream(comm_stream);\n    }\n\n    // Wait for previous tasks to finish\n    if (previous_event.has_value()) {\n        stream_wait(comm_stream, previous_event.value());\n    } else {\n        stream_wait(comm_stream, compute_stream);\n    }\n\n    // Create handles (only returned in non-cached mode)\n    int num_recv_tokens = -1;\n    auto rank_prefix_matrix = torch::Tensor();\n    auto channel_prefix_matrix = torch::Tensor();\n    std::vector<int> num_recv_tokens_per_expert_list;\n\n    // Barrier or send sizes\n    // To clean: channel start/end offsets, head and tail\n    int num_memset_int = num_channels * num_ranks * 4;\n    if (cached_mode) {\n        num_recv_tokens = cached_num_recv_tokens;\n        rank_prefix_matrix = cached_rank_prefix_matrix.value();\n        channel_prefix_matrix = cached_channel_prefix_matrix.value();\n\n        // Copy rank prefix matrix and clean flags\n        intranode::cached_notify_dispatch(\n            rank_prefix_matrix.data_ptr<int>(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream);\n    } else {\n        rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));\n        channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));\n\n        // Send sizes\n        // Meta information:\n        //  - Size prefix by ranks, shaped as `[num_ranks, num_ranks]`\n
        //  - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]`\n        // NOTES: no more token dropping in this version\n        *moe_recv_counter = -1;\n        for (int i = 0; i < num_local_experts; ++i)\n            moe_recv_expert_counter[i] = -1;\n        EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes);\n        intranode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(),\n                                   moe_recv_counter_mapped,\n                                   num_ranks,\n                                   num_tokens_per_expert->data_ptr<int>(),\n                                   moe_recv_expert_counter_mapped,\n                                   num_experts,\n                                   num_tokens,\n                                   is_token_in_rank.data_ptr<bool>(),\n                                   channel_prefix_matrix.data_ptr<int>(),\n                                   rank_prefix_matrix.data_ptr<int>(),\n                                   num_memset_int,\n                                   expert_alignment,\n                                   buffer_ptrs_gpu,\n                                   barrier_signal_ptrs_gpu,\n                                   rank,\n                                   comm_stream,\n                                   num_channels);\n\n        if (num_worst_tokens > 0) {\n            // No CPU sync: allocate for the worst case\n            num_recv_tokens = num_worst_tokens;\n\n            // Must be a forward pass with top-k tensors\n            EP_HOST_ASSERT(topk_idx.has_value());\n            EP_HOST_ASSERT(topk_weights.has_value());\n        } else {\n            // Synchronize total received tokens and tokens per expert\n            auto start_time = std::chrono::high_resolution_clock::now();\n            while (true) {\n                // Read total count\n                num_recv_tokens = static_cast<int>(*moe_recv_counter);\n\n  
              // Read per-expert count\n                bool ready = (num_recv_tokens >= 0);\n                for (int i = 0; i < num_local_experts and ready; ++i)\n                    ready &= moe_recv_expert_counter[i] >= 0;\n\n                if (ready)\n                    break;\n\n                // Timeout check\n                if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >\n                    NUM_CPU_TIMEOUT_SECS)\n                    throw std::runtime_error(\"DeepEP error: CPU recv timeout\");\n            }\n            num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);\n        }\n    }\n\n    // Allocate new tensors\n    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());\n    auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA));\n    auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(),\n         recv_x_scales = std::optional<torch::Tensor>();\n    auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));\n    auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));\n\n    // Assign pointers\n    topk_idx_t* recv_topk_idx_ptr = nullptr;\n    float* recv_topk_weights_ptr = nullptr;\n    float* recv_x_scales_ptr = nullptr;\n    if (topk_idx.has_value()) {\n        recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());\n        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());\n        recv_topk_idx_ptr = recv_topk_idx->data_ptr<topk_idx_t>();\n        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();\n    }\n    if (x_scales.has_value()) {\n        recv_x_scales = x_scales->dim() == 1 ? 
torch::empty({num_recv_tokens}, x_scales->options())\n                                             : torch::empty({num_recv_tokens, num_scales}, x_scales->options());\n        recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());\n    }\n\n    // Dispatch\n    EP_HOST_ASSERT(\n        num_ranks * num_ranks * sizeof(int) +                                                                     // Size prefix matrix\n            num_channels * num_ranks * sizeof(int) +                                                              // Channel start offset\n            num_channels * num_ranks * sizeof(int) +                                                              // Channel end offset\n            num_channels * num_ranks * sizeof(int) * 2 +                                                          // Queue head and tail\n            num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() +  // Data buffer\n            num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) +                     // Source index buffer\n            num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) +   // Top-k index buffer\n            num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) +        // Top-k weight buffer\n            num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales        // FP8 scale buffer\n        <= num_nvl_bytes);\n    intranode::dispatch(recv_x.data_ptr(),\n                        recv_x_scales_ptr,\n                        recv_src_idx.data_ptr<int>(),\n                        recv_topk_idx_ptr,\n                        recv_topk_weights_ptr,\n                        recv_channel_prefix_matrix.data_ptr<int>(),\n                        send_head.data_ptr<int>(),\n                        x.data_ptr(),\n                        x_scales_ptr,\n                 
       topk_idx_ptr,\n                        topk_weights_ptr,\n                        is_token_in_rank.data_ptr<bool>(),\n                        channel_prefix_matrix.data_ptr<int>(),\n                        num_tokens,\n                        num_worst_tokens,\n                        static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),\n                        num_topk,\n                        num_experts,\n                        num_scales,\n                        scale_token_stride,\n                        scale_hidden_stride,\n                        buffer_ptrs_gpu,\n                        rank,\n                        num_ranks,\n                        comm_stream,\n                        config.num_sms,\n                        config.num_max_nvl_chunked_send_tokens,\n                        config.num_max_nvl_chunked_recv_tokens);\n\n    // Wait streams\n    std::optional<EventHandle> event;\n    if (async) {\n        event = EventHandle(comm_stream);\n        for (auto& t : {x,\n                        is_token_in_rank,\n                        rank_prefix_matrix,\n                        channel_prefix_matrix,\n                        recv_x,\n                        recv_src_idx,\n                        recv_channel_prefix_matrix,\n                        send_head}) {\n            t.record_stream(comm_stream);\n            if (allocate_on_comm_stream)\n                t.record_stream(compute_stream);\n        }\n        for (auto& to : {x_scales,\n                         topk_idx,\n                         topk_weights,\n                         num_tokens_per_rank,\n                         num_tokens_per_expert,\n                         cached_channel_prefix_matrix,\n                         cached_rank_prefix_matrix,\n                         recv_topk_idx,\n                         recv_topk_weights,\n                         recv_x_scales}) {\n            to.has_value() ? 
to->record_stream(comm_stream) : void();\n            if (allocate_on_comm_stream)\n                to.has_value() ? to->record_stream(compute_stream) : void();\n        }\n    } else {\n        stream_wait(compute_stream, comm_stream);\n    }\n\n    // Switch back compute stream\n    if (allocate_on_comm_stream)\n        at::cuda::setCurrentCUDAStream(compute_stream);\n\n    // Return values\n    return {recv_x,\n            recv_x_scales,\n            recv_topk_idx,\n            recv_topk_weights,\n            num_recv_tokens_per_expert_list,\n            rank_prefix_matrix,\n            channel_prefix_matrix,\n            recv_channel_prefix_matrix,\n            recv_src_idx,\n            send_head,\n            event};\n}\n\nstd::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> Buffer::intranode_combine(\n    const torch::Tensor& x,\n    const std::optional<torch::Tensor>& topk_weights,\n    const std::optional<torch::Tensor>& bias_0,\n    const std::optional<torch::Tensor>& bias_1,\n    const torch::Tensor& src_idx,\n    const torch::Tensor& rank_prefix_matrix,\n    const torch::Tensor& channel_prefix_matrix,\n    const torch::Tensor& send_head,\n    const Config& config,\n    std::optional<EventHandle>& previous_event,\n    bool async,\n    bool allocate_on_comm_stream) {\n    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());\n    EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32);\n    EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32);\n    EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and\n                   rank_prefix_matrix.scalar_type() == torch::kInt32);\n    EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and\n                   channel_prefix_matrix.scalar_type() == torch::kInt32);\n\n    // One channel uses two blocks, 
even-numbered blocks for sending, odd-numbered blocks for receiving.\n    EP_HOST_ASSERT(config.num_sms % 2 == 0);\n    int num_channels = config.num_sms / 2;\n\n    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));\n    auto num_recv_tokens = static_cast<int>(send_head.size(0));\n    EP_HOST_ASSERT(src_idx.size(0) == num_tokens);\n    EP_HOST_ASSERT(send_head.size(1) == num_ranks);\n    EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks);\n    EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels);\n    EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);\n\n    // Allocate all tensors on comm stream if set\n    // NOTES: do not allocate tensors upfront!\n    auto compute_stream = at::cuda::getCurrentCUDAStream();\n    if (allocate_on_comm_stream) {\n        EP_HOST_ASSERT(previous_event.has_value() and async);\n        at::cuda::setCurrentCUDAStream(comm_stream);\n    }\n\n    // Wait previous tasks to be finished\n    if (previous_event.has_value()) {\n        stream_wait(comm_stream, previous_event.value());\n    } else {\n        stream_wait(comm_stream, compute_stream);\n    }\n\n    int num_topk = 0;\n    auto recv_topk_weights = std::optional<torch::Tensor>();\n    float* topk_weights_ptr = nullptr;\n    float* recv_topk_weights_ptr = nullptr;\n    if (topk_weights.has_value()) {\n        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());\n        EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);\n        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);\n        num_topk = static_cast<int>(topk_weights->size(1));\n        topk_weights_ptr = topk_weights->data_ptr<float>();\n        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());\n        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();\n    }\n\n    // Launch barrier 
and reset queue head and tail\n    EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);\n    intranode::cached_notify_combine(buffer_ptrs_gpu,\n                                     send_head.data_ptr<int>(),\n                                     num_channels,\n                                     num_recv_tokens,\n                                     num_channels * num_ranks * 2,\n                                     barrier_signal_ptrs_gpu,\n                                     rank,\n                                     num_ranks,\n                                     comm_stream);\n\n    // Assign bias pointers\n    auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});\n    void* bias_ptrs[2] = {nullptr, nullptr};\n    for (int i = 0; i < 2; ++i)\n        if (bias_opts[i].has_value()) {\n            auto bias = bias_opts[i].value();\n            EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());\n            EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());\n            EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);\n            bias_ptrs[i] = bias.data_ptr();\n        }\n\n    // Combine data\n    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());\n    EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 +  // Queue head and tail\n                       num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() +  // Data buffer\n                       num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) +             // Source index buffer\n                       num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float)  // Top-k weight buffer\n                   <= num_nvl_bytes);\n    intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),\n                       recv_x.data_ptr(),\n                       recv_topk_weights_ptr,\n          
             x.data_ptr(),\n                       topk_weights_ptr,\n                       bias_ptrs[0],\n                       bias_ptrs[1],\n                       src_idx.data_ptr<int>(),\n                       rank_prefix_matrix.data_ptr<int>(),\n                       channel_prefix_matrix.data_ptr<int>(),\n                       send_head.data_ptr<int>(),\n                       num_tokens,\n                       num_recv_tokens,\n                       hidden,\n                       num_topk,\n                       buffer_ptrs_gpu,\n                       rank,\n                       num_ranks,\n                       comm_stream,\n                       config.num_sms,\n                       config.num_max_nvl_chunked_send_tokens,\n                       config.num_max_nvl_chunked_recv_tokens);\n\n    // Wait streams\n    std::optional<EventHandle> event;\n    if (async) {\n        event = EventHandle(comm_stream);\n        for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {\n            t.record_stream(comm_stream);\n            if (allocate_on_comm_stream)\n                t.record_stream(compute_stream);\n        }\n        for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) {\n            to.has_value() ? to->record_stream(comm_stream) : void();\n            if (allocate_on_comm_stream)\n                to.has_value() ? 
to->record_stream(compute_stream) : void();\n        }\n    } else {\n        stream_wait(compute_stream, comm_stream);\n    }\n\n    // Switch back compute stream\n    if (allocate_on_comm_stream)\n        at::cuda::setCurrentCUDAStream(compute_stream);\n\n    return {recv_x, recv_topk_weights, event};\n}\n\nstd::tuple<torch::Tensor,\n           std::optional<torch::Tensor>,\n           std::optional<torch::Tensor>,\n           std::optional<torch::Tensor>,\n           std::vector<int>,\n           torch::Tensor,\n           torch::Tensor,\n           std::optional<torch::Tensor>,\n           torch::Tensor,\n           std::optional<torch::Tensor>,\n           torch::Tensor,\n           std::optional<torch::Tensor>,\n           std::optional<torch::Tensor>,\n           std::optional<torch::Tensor>,\n           std::optional<EventHandle>>\nBuffer::internode_dispatch(const torch::Tensor& x,\n                           const std::optional<torch::Tensor>& x_scales,\n                           const std::optional<torch::Tensor>& topk_idx,\n                           const std::optional<torch::Tensor>& topk_weights,\n                           const std::optional<torch::Tensor>& num_tokens_per_rank,\n                           const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,\n                           const torch::Tensor& is_token_in_rank,\n                           const std::optional<torch::Tensor>& num_tokens_per_expert,\n                           int cached_num_recv_tokens,\n                           int cached_num_rdma_recv_tokens,\n                           const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix,\n                           const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,\n                           const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,\n                           const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,\n                           int 
expert_alignment,\n                           int num_worst_tokens,\n                           const Config& config,\n                           std::optional<EventHandle>& previous_event,\n                           bool async,\n                           bool allocate_on_comm_stream) {\n#ifndef DISABLE_NVSHMEM\n    // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long.\n    // If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL\n    // unless we release GIL here.\n    pybind11::gil_scoped_release release;\n\n    const int num_channels = config.num_sms / 2;\n    EP_HOST_ASSERT(config.num_sms % 2 == 0);\n    EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);\n\n    bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();\n    if (cached_mode) {\n        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value());\n        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value());\n        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value());\n        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value());\n    } else {\n        EP_HOST_ASSERT(num_tokens_per_rank.has_value());\n        EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value());\n        EP_HOST_ASSERT(num_tokens_per_expert.has_value());\n    }\n\n    // Type checks\n    if (cached_mode) {\n        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32);\n        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32);\n        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32);\n        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32);\n    } else {\n        EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);\n        EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == 
torch::kInt32);\n        EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);\n    }\n\n    // Shape and contiguous checks\n    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());\n    EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);\n    if (cached_mode) {\n        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous());\n        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and\n                       cached_rdma_channel_prefix_matrix->size(1) == num_channels);\n        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous());\n        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);\n        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous());\n        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and\n                       cached_gbl_channel_prefix_matrix->size(1) == num_channels);\n        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous());\n        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);\n    } else {\n        EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());\n        EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous());\n        EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());\n        EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);\n        EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks);\n        EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);\n        EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);\n    }\n\n    auto num_tokens = 
static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)),\n         hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));\n    auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;\n\n    // Top-k checks\n    int num_topk = 0;\n    topk_idx_t* topk_idx_ptr = nullptr;\n    float* topk_weights_ptr = nullptr;\n    EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());\n    if (topk_idx.has_value()) {\n        num_topk = static_cast<int>(topk_idx->size(1));\n        EP_HOST_ASSERT(num_experts > 0);\n        EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());\n        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());\n        EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));\n        EP_HOST_ASSERT(num_topk == topk_weights->size(1));\n        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);\n        topk_idx_ptr = topk_idx->data_ptr<topk_idx_t>();\n        topk_weights_ptr = topk_weights->data_ptr<float>();\n    }\n\n    // FP8 scales checks\n    float* x_scales_ptr = nullptr;\n    int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;\n    if (x_scales.has_value()) {\n        EP_HOST_ASSERT(x.element_size() == 1);\n        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);\n        EP_HOST_ASSERT(x_scales->dim() == 2);\n        EP_HOST_ASSERT(x_scales->size(0) == num_tokens);\n        num_scales = x_scales->dim() == 1 ? 
1 : static_cast<int>(x_scales->size(1));\n        x_scales_ptr = static_cast<float*>(x_scales->data_ptr());\n        scale_token_stride = static_cast<int>(x_scales->stride(0));\n        scale_hidden_stride = static_cast<int>(x_scales->stride(1));\n    }\n\n    // Allocate all tensors on comm stream if set\n    // NOTES: do not allocate tensors upfront!\n    auto compute_stream = at::cuda::getCurrentCUDAStream();\n    if (allocate_on_comm_stream) {\n        EP_HOST_ASSERT(previous_event.has_value() and async);\n        at::cuda::setCurrentCUDAStream(comm_stream);\n    }\n\n    // Wait previous tasks to be finished\n    if (previous_event.has_value()) {\n        stream_wait(comm_stream, previous_event.value());\n    } else {\n        stream_wait(comm_stream, compute_stream);\n    }\n\n    // Create handles (only return for non-cached mode)\n    int num_recv_tokens = -1, num_rdma_recv_tokens = -1;\n    auto rdma_channel_prefix_matrix = torch::Tensor();\n    auto recv_rdma_rank_prefix_sum = torch::Tensor();\n    auto gbl_channel_prefix_matrix = torch::Tensor();\n    auto recv_gbl_rank_prefix_sum = torch::Tensor();\n    std::vector<int> num_recv_tokens_per_expert_list;\n\n    // Barrier or send sizes\n    if (cached_mode) {\n        num_recv_tokens = cached_num_recv_tokens;\n        num_rdma_recv_tokens = cached_num_rdma_recv_tokens;\n        rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();\n        recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value();\n        gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();\n        recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();\n\n        // Just a barrier and clean flags\n        internode::cached_notify(hidden_int4,\n                                 num_scales,\n                                 num_topk,\n                                 num_topk,\n                                 num_ranks,\n                                 num_channels,\n             
                    0,\n                                 nullptr,\n                                 nullptr,\n                                 nullptr,\n                                 nullptr,\n                                 rdma_buffer_ptr,\n                                 config.num_max_rdma_chunked_recv_tokens,\n                                 buffer_ptrs_gpu,\n                                 config.num_max_nvl_chunked_recv_tokens,\n                                 barrier_signal_ptrs_gpu,\n                                 rank,\n                                 comm_stream,\n                                 config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),\n                                 num_nvl_bytes,\n                                 true,\n                                 low_latency_mode);\n    } else {\n        rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));\n        recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));\n        gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));\n        recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));\n\n        // Send sizes\n        *moe_recv_counter = -1, *moe_recv_rdma_counter = -1;\n        for (int i = 0; i < num_local_experts; ++i)\n            moe_recv_expert_counter[i] = -1;\n        internode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(),\n                                   moe_recv_counter_mapped,\n                                   num_ranks,\n                                   num_tokens_per_rdma_rank->data_ptr<int>(),\n                                   moe_recv_rdma_counter_mapped,\n                                   num_tokens_per_expert->data_ptr<int>(),\n                                   moe_recv_expert_counter_mapped,\n                     
              num_experts,\n                                   is_token_in_rank.data_ptr<bool>(),\n                                   num_tokens,\n                                   num_worst_tokens,\n                                   num_channels,\n                                   hidden_int4,\n                                   num_scales,\n                                   num_topk,\n                                   expert_alignment,\n                                   rdma_channel_prefix_matrix.data_ptr<int>(),\n                                   recv_rdma_rank_prefix_sum.data_ptr<int>(),\n                                   gbl_channel_prefix_matrix.data_ptr<int>(),\n                                   recv_gbl_rank_prefix_sum.data_ptr<int>(),\n                                   rdma_buffer_ptr,\n                                   config.num_max_rdma_chunked_recv_tokens,\n                                   buffer_ptrs_gpu,\n                                   config.num_max_nvl_chunked_recv_tokens,\n                                   barrier_signal_ptrs_gpu,\n                                   rank,\n                                   comm_stream,\n                                   config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),\n                                   num_nvl_bytes,\n                                   low_latency_mode);\n\n        // Synchronize total received tokens and tokens per expert\n        if (num_worst_tokens > 0) {\n            num_recv_tokens = num_worst_tokens;\n            num_rdma_recv_tokens = num_worst_tokens;\n        } else {\n            auto start_time = std::chrono::high_resolution_clock::now();\n            while (true) {\n                // Read total count\n                num_recv_tokens = static_cast<int>(*moe_recv_counter);\n                num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);\n\n                // Read per-expert count\n                bool ready = 
(num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);\n                for (int i = 0; i < num_local_experts and ready; ++i)\n                    ready &= moe_recv_expert_counter[i] >= 0;\n\n                if (ready)\n                    break;\n\n                // Timeout check\n                if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >\n                    NUM_CPU_TIMEOUT_SECS) {\n                    printf(\"Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\\n\", rank, num_recv_tokens, num_rdma_recv_tokens);\n                    for (int i = 0; i < num_local_experts; ++i)\n                        printf(\"moe_recv_expert_counter[%d]: %d\\n\", i, moe_recv_expert_counter[i]);\n                    throw std::runtime_error(\"DeepEP error: timeout (dispatch CPU)\");\n                }\n            }\n            num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);\n        }\n    }\n\n    // Allocate new tensors\n    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());\n    auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(),\n         recv_x_scales = std::optional<torch::Tensor>();\n    auto recv_src_meta = std::optional<torch::Tensor>();\n    auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();\n    auto recv_gbl_channel_prefix_matrix = std::optional<torch::Tensor>();\n    auto send_rdma_head = std::optional<torch::Tensor>();\n    auto send_nvl_head = std::optional<torch::Tensor>();\n    if (not cached_mode) {\n        recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA));\n        recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));\n        recv_gbl_channel_prefix_matrix = 
torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));\n        send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));\n        send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA));\n    }\n\n    // Assign pointers\n    topk_idx_t* recv_topk_idx_ptr = nullptr;\n    float* recv_topk_weights_ptr = nullptr;\n    float* recv_x_scales_ptr = nullptr;\n    if (topk_idx.has_value()) {\n        recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());\n        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());\n        recv_topk_idx_ptr = recv_topk_idx->data_ptr<topk_idx_t>();\n        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();\n    }\n    if (x_scales.has_value()) {\n        recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options())\n                                             : torch::empty({num_recv_tokens, num_scales}, x_scales->options());\n        recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());\n    }\n\n    // Launch data dispatch\n    // NOTES: the buffer size checks are moved into the `.cu` file\n    internode::dispatch(recv_x.data_ptr(),\n                        recv_x_scales_ptr,\n                        recv_topk_idx_ptr,\n                        recv_topk_weights_ptr,\n                        cached_mode ? nullptr : recv_src_meta->data_ptr(),\n                        x.data_ptr(),\n                        x_scales_ptr,\n                        topk_idx_ptr,\n                        topk_weights_ptr,\n                        cached_mode ? nullptr : send_rdma_head->data_ptr<int>(),\n                        cached_mode ? nullptr : send_nvl_head->data_ptr<int>(),\n                        cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr<int>(),\n                        cached_mode ? 
nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),\n                        rdma_channel_prefix_matrix.data_ptr<int>(),\n                        recv_rdma_rank_prefix_sum.data_ptr<int>(),\n                        gbl_channel_prefix_matrix.data_ptr<int>(),\n                        recv_gbl_rank_prefix_sum.data_ptr<int>(),\n                        is_token_in_rank.data_ptr<bool>(),\n                        num_tokens,\n                        num_worst_tokens,\n                        hidden_int4,\n                        num_scales,\n                        num_topk,\n                        num_experts,\n                        scale_token_stride,\n                        scale_hidden_stride,\n                        rdma_buffer_ptr,\n                        config.num_max_rdma_chunked_send_tokens,\n                        config.num_max_rdma_chunked_recv_tokens,\n                        buffer_ptrs_gpu,\n                        config.num_max_nvl_chunked_send_tokens,\n                        config.num_max_nvl_chunked_recv_tokens,\n                        rank,\n                        num_ranks,\n                        cached_mode,\n                        comm_stream,\n                        num_channels,\n                        low_latency_mode);\n\n    // Wait streams\n    std::optional<EventHandle> event;\n    if (async) {\n        event = EventHandle(comm_stream);\n        for (auto& t : {x,\n                        is_token_in_rank,\n                        recv_x,\n                        rdma_channel_prefix_matrix,\n                        recv_rdma_rank_prefix_sum,\n                        gbl_channel_prefix_matrix,\n                        recv_gbl_rank_prefix_sum}) {\n            t.record_stream(comm_stream);\n            if (allocate_on_comm_stream)\n                t.record_stream(compute_stream);\n        }\n        for (auto& to : {x_scales,\n                         topk_idx,\n                         topk_weights,\n                     
    num_tokens_per_rank,\n                         num_tokens_per_rdma_rank,\n                         num_tokens_per_expert,\n                         cached_rdma_channel_prefix_matrix,\n                         cached_recv_rdma_rank_prefix_sum,\n                         cached_gbl_channel_prefix_matrix,\n                         cached_recv_gbl_rank_prefix_sum,\n                         recv_topk_idx,\n                         recv_topk_weights,\n                         recv_x_scales,\n                         recv_rdma_channel_prefix_matrix,\n                         recv_gbl_channel_prefix_matrix,\n                         send_rdma_head,\n                         send_nvl_head,\n                         recv_src_meta}) {\n            to.has_value() ? to->record_stream(comm_stream) : void();\n            if (allocate_on_comm_stream)\n                to.has_value() ? to->record_stream(compute_stream) : void();\n        }\n    } else {\n        stream_wait(compute_stream, comm_stream);\n    }\n\n    // Switch back compute stream\n    if (allocate_on_comm_stream)\n        at::cuda::setCurrentCUDAStream(compute_stream);\n\n    // Return values\n    return {recv_x,\n            recv_x_scales,\n            recv_topk_idx,\n            recv_topk_weights,\n            num_recv_tokens_per_expert_list,\n            rdma_channel_prefix_matrix,\n            gbl_channel_prefix_matrix,\n            recv_rdma_channel_prefix_matrix,\n            recv_rdma_rank_prefix_sum,\n            recv_gbl_channel_prefix_matrix,\n            recv_gbl_rank_prefix_sum,\n            recv_src_meta,\n            send_rdma_head,\n            send_nvl_head,\n            event};\n#else\n    EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n    return {};\n#endif\n}\n\nstd::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> Buffer::internode_combine(\n    const torch::Tensor& x,\n    const std::optional<torch::Tensor>& topk_weights,\n    const 
std::optional<torch::Tensor>& bias_0,\n    const std::optional<torch::Tensor>& bias_1,\n    const torch::Tensor& src_meta,\n    const torch::Tensor& is_combined_token_in_rank,\n    const torch::Tensor& rdma_channel_prefix_matrix,\n    const torch::Tensor& rdma_rank_prefix_sum,\n    const torch::Tensor& gbl_channel_prefix_matrix,\n    const torch::Tensor& combined_rdma_head,\n    const torch::Tensor& combined_nvl_head,\n    const Config& config,\n    std::optional<EventHandle>& previous_event,\n    bool async,\n    bool allocate_on_comm_stream) {\n#ifndef DISABLE_NVSHMEM\n    const int num_channels = config.num_sms / 2;\n    EP_HOST_ASSERT(config.num_sms % 2 == 0);\n\n    // Shape and contiguous checks\n    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());\n    EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte);\n    EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and\n                   is_combined_token_in_rank.scalar_type() == torch::kBool);\n    EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and\n                   rdma_channel_prefix_matrix.scalar_type() == torch::kInt32);\n    EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and\n                   rdma_rank_prefix_sum.scalar_type() == torch::kInt32);\n    EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and\n                   gbl_channel_prefix_matrix.scalar_type() == torch::kInt32);\n    EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and\n                   combined_rdma_head.scalar_type() == torch::kInt32);\n    EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32);\n\n    auto num_tokens = static_cast<int>(x.size(0)), hidden = 
static_cast<int>(x.size(1)),\n         hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));\n    auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));\n    EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);\n    EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes());\n    EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);\n    EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels);\n    EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);\n    EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels);\n    EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and\n                   combined_rdma_head.size(1) == num_rdma_ranks);\n    EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);\n\n    // Allocate all tensors on comm stream if set\n    // NOTES: do not allocate tensors upfront!\n    auto compute_stream = at::cuda::getCurrentCUDAStream();\n    if (allocate_on_comm_stream) {\n        EP_HOST_ASSERT(previous_event.has_value() and async);\n        at::cuda::setCurrentCUDAStream(comm_stream);\n    }\n\n    // Wait for previous tasks to finish\n    if (previous_event.has_value()) {\n        stream_wait(comm_stream, previous_event.value());\n    } else {\n        stream_wait(comm_stream, compute_stream);\n    }\n\n    // Top-k checks\n    int num_topk = 0;\n    auto combined_topk_weights = std::optional<torch::Tensor>();\n    float* topk_weights_ptr = nullptr;\n    float* combined_topk_weights_ptr = nullptr;\n    if (topk_weights.has_value()) {\n        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());\n        EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);\n        EP_HOST_ASSERT(topk_weights->scalar_type() == 
torch::kFloat32);\n        num_topk = static_cast<int>(topk_weights->size(1));\n        topk_weights_ptr = topk_weights->data_ptr<float>();\n        combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options());\n        combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();\n    }\n\n    // Extra checks for the deadlock-avoidance design\n    EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);\n    EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);\n\n    // Launch barrier and reset queue head and tail\n    internode::cached_notify(hidden_int4,\n                             0,\n                             0,\n                             num_topk,\n                             num_ranks,\n                             num_channels,\n                             num_combined_tokens,\n                             combined_rdma_head.data_ptr<int>(),\n                             rdma_channel_prefix_matrix.data_ptr<int>(),\n                             rdma_rank_prefix_sum.data_ptr<int>(),\n                             combined_nvl_head.data_ptr<int>(),\n                             rdma_buffer_ptr,\n                             config.num_max_rdma_chunked_recv_tokens,\n                             buffer_ptrs_gpu,\n                             config.num_max_nvl_chunked_recv_tokens,\n                             barrier_signal_ptrs_gpu,\n                             rank,\n                             comm_stream,\n                             config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),\n                             num_nvl_bytes,\n                             false,\n                             low_latency_mode);\n\n    // Assign bias pointers\n    auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});\n    void* bias_ptrs[2] = {nullptr, nullptr};\n    for (int i = 0; i < 2; ++i)\n       
 if (bias_opts[i].has_value()) {\n            auto bias = bias_opts[i].value();\n            EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());\n            EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());\n            EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);\n            bias_ptrs[i] = bias.data_ptr();\n        }\n\n    // Launch data combine\n    auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());\n    internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),\n                       combined_x.data_ptr(),\n                       combined_topk_weights_ptr,\n                       is_combined_token_in_rank.data_ptr<bool>(),\n                       x.data_ptr(),\n                       topk_weights_ptr,\n                       bias_ptrs[0],\n                       bias_ptrs[1],\n                       combined_rdma_head.data_ptr<int>(),\n                       combined_nvl_head.data_ptr<int>(),\n                       src_meta.data_ptr(),\n                       rdma_channel_prefix_matrix.data_ptr<int>(),\n                       rdma_rank_prefix_sum.data_ptr<int>(),\n                       gbl_channel_prefix_matrix.data_ptr<int>(),\n                       num_tokens,\n                       num_combined_tokens,\n                       hidden,\n                       num_topk,\n                       rdma_buffer_ptr,\n                       config.num_max_rdma_chunked_send_tokens,\n                       config.num_max_rdma_chunked_recv_tokens,\n                       buffer_ptrs_gpu,\n                       config.num_max_nvl_chunked_send_tokens,\n                       config.num_max_nvl_chunked_recv_tokens,\n                       rank,\n                       num_ranks,\n                       comm_stream,\n                       num_channels,\n                       low_latency_mode);\n\n    // Wait streams\n    std::optional<EventHandle> event;\n    if (async) {\n        
event = EventHandle(comm_stream);\n        for (auto& t : {x,\n                        src_meta,\n                        is_combined_token_in_rank,\n                        rdma_channel_prefix_matrix,\n                        rdma_rank_prefix_sum,\n                        gbl_channel_prefix_matrix,\n                        combined_x,\n                        combined_rdma_head,\n                        combined_nvl_head}) {\n            t.record_stream(comm_stream);\n            if (allocate_on_comm_stream)\n                t.record_stream(compute_stream);\n        }\n        for (auto& to : {topk_weights, combined_topk_weights, bias_0, bias_1}) {\n            to.has_value() ? to->record_stream(comm_stream) : void();\n            if (allocate_on_comm_stream)\n                to.has_value() ? to->record_stream(compute_stream) : void();\n        }\n    } else {\n        stream_wait(compute_stream, comm_stream);\n    }\n\n    // Switch back compute stream\n    if (allocate_on_comm_stream)\n        at::cuda::setCurrentCUDAStream(compute_stream);\n\n    // Return values\n    return {combined_x, combined_topk_weights, event};\n#else\n    EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n    return {};\n#endif\n}\n\nvoid Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {\n#ifndef DISABLE_NVSHMEM\n    EP_HOST_ASSERT(low_latency_mode);\n\n    auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);\n    auto clean_meta_0 = layout.buffers[0].clean_meta();\n    auto clean_meta_1 = layout.buffers[1].clean_meta();\n\n    auto check_boundary = [=](void* ptr, size_t num_bytes) {\n        auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr);\n        EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes);\n    };\n    check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int));\n    
check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int));\n\n    internode_ll::clean_low_latency_buffer(clean_meta_0.first,\n                                           clean_meta_0.second,\n                                           clean_meta_1.first,\n                                           clean_meta_1.second,\n                                           rank,\n                                           num_ranks,\n                                           mask_buffer_ptr,\n                                           sync_buffer_ptr,\n                                           at::cuda::getCurrentCUDAStream());\n#else\n    EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n#endif\n}\n\nstd::tuple<torch::Tensor,\n           std::optional<torch::Tensor>,\n           torch::Tensor,\n           torch::Tensor,\n           torch::Tensor,\n           std::optional<EventHandle>,\n           std::optional<std::function<void()>>>\nBuffer::low_latency_dispatch(const torch::Tensor& x,\n                             const torch::Tensor& topk_idx,\n                             const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,\n                             const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,\n                             int num_max_dispatch_tokens_per_rank,\n                             int num_experts,\n                             bool use_fp8,\n                             bool round_scale,\n                             bool use_ue8m0,\n                             bool async,\n                             bool return_recv_hook) {\n#ifndef DISABLE_NVSHMEM\n    EP_HOST_ASSERT(low_latency_mode);\n\n    // Tensor checks\n    // By default using `ptp128c` FP8 cast\n    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);\n    EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);\n    EP_HOST_ASSERT(topk_idx.dim() == 2 and 
topk_idx.is_contiguous());\n    EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);\n    EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);\n    EP_HOST_ASSERT(num_experts % num_ranks == 0);\n\n    // Diagnostic tensors\n    if (cumulative_local_expert_recv_stats.has_value()) {\n        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);\n        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());\n        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);\n    }\n    if (dispatch_wait_recv_cost_stats.has_value()) {\n        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);\n        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());\n        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);\n    }\n\n    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));\n    auto num_topk = static_cast<int>(topk_idx.size(1));\n    auto num_local_experts = num_experts / num_ranks;\n\n    // Buffer control\n    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);\n    EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);\n    auto buffer = layout.buffers[low_latency_buffer_idx];\n    auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];\n\n    // Wait for previous tasks to finish\n    // NOTES: the hook mode always launches on the current (compute) stream\n    auto compute_stream = at::cuda::getCurrentCUDAStream();\n    auto launch_stream = return_recv_hook ? 
compute_stream : comm_stream;\n    EP_HOST_ASSERT(not(async and return_recv_hook));\n    if (not return_recv_hook)\n        stream_wait(launch_stream, compute_stream);\n\n    // Allocate packed tensors\n    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},\n                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));\n    auto packed_recv_src_info =\n        torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));\n    auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));\n    auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));\n\n    // Allocate column-major scales\n    auto packed_recv_x_scales = std::optional<torch::Tensor>();\n    void* packed_recv_x_scales_ptr = nullptr;\n    EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and \"TMA requires the number of tokens to be a multiple of 4\");\n\n    if (use_fp8) {\n        // TODO: support unaligned cases\n        EP_HOST_ASSERT(hidden % 512 == 0);\n        if (not use_ue8m0) {\n            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},\n                                                torch::dtype(torch::kFloat32).device(torch::kCUDA));\n        } else {\n            EP_HOST_ASSERT(round_scale);\n            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},\n                                                torch::dtype(torch::kInt).device(torch::kCUDA));\n        }\n        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);\n        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();\n    }\n\n    // Kernel launch\n  
  auto next_clean_meta = next_buffer.clean_meta();\n    auto launcher = [=](int phases) {\n        internode_ll::dispatch(\n            packed_recv_x.data_ptr(),\n            packed_recv_x_scales_ptr,\n            packed_recv_src_info.data_ptr<int>(),\n            packed_recv_layout_range.data_ptr<int64_t>(),\n            packed_recv_count.data_ptr<int>(),\n            mask_buffer_ptr,\n            cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,\n            dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,\n            buffer.dispatch_rdma_recv_data_buffer,\n            buffer.dispatch_rdma_recv_count_buffer,\n            buffer.dispatch_rdma_send_buffer,\n            x.data_ptr(),\n            topk_idx.data_ptr<topk_idx_t>(),\n            next_clean_meta.first,\n            next_clean_meta.second,\n            num_tokens,\n            hidden,\n            num_max_dispatch_tokens_per_rank,\n            num_topk,\n            num_experts,\n            rank,\n            num_ranks,\n            use_fp8,\n            round_scale,\n            use_ue8m0,\n            workspace,\n            num_device_sms,\n            launch_stream,\n            phases);\n    };\n    launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));\n\n    // Wait streams\n    std::optional<EventHandle> event;\n    if (async) {\n        // NOTES: we must ensure that no tensor is deallocated before the stream wait happens,\n        // so the Python API must wrap all tensors into the event handle.\n        event = EventHandle(launch_stream);\n    } else if (not return_recv_hook) {\n        stream_wait(compute_stream, launch_stream);\n    }\n\n    // Receiver callback\n    std::optional<std::function<void()>> recv_hook = std::nullopt;\n    if (return_recv_hook)\n        recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };\n\n    // Return values\n    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};\n#else\n    EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n    return {};\n#endif\n}\n\nstd::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> Buffer::low_latency_combine(\n    const torch::Tensor& x,\n    const torch::Tensor& topk_idx,\n    const torch::Tensor& topk_weights,\n    const torch::Tensor& src_info,\n    const torch::Tensor& layout_range,\n    const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,\n    int num_max_dispatch_tokens_per_rank,\n    int num_experts,\n    bool use_logfmt,\n    bool zero_copy,\n    bool async,\n    bool return_recv_hook,\n    const std::optional<torch::Tensor>& out) {\n#ifndef DISABLE_NVSHMEM\n    EP_HOST_ASSERT(low_latency_mode);\n\n    // Tensor checks\n    EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);\n    EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);\n    EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank);\n    EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);\n    EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());\n    
EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));\n    EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);\n    EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());\n    EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);\n    EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);\n    EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());\n    EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));\n    EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());\n    EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);\n    EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);\n\n    if (combine_wait_recv_cost_stats.has_value()) {\n        EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);\n        EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());\n        EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);\n    }\n\n    auto hidden = static_cast<int>(x.size(2));\n    auto num_topk = static_cast<int>(topk_weights.size(1));\n    auto num_combined_tokens = static_cast<int>(topk_weights.size(0));\n\n    // Buffer control\n    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);\n    EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);\n    auto buffer = layout.buffers[low_latency_buffer_idx];\n    auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];\n\n    // Wait for previous tasks to finish\n    // NOTES: the hook mode always launches on the current (compute) stream\n    auto compute_stream = at::cuda::getCurrentCUDAStream();\n    auto launch_stream = return_recv_hook ? 
compute_stream : comm_stream;\n    EP_HOST_ASSERT(not(async and return_recv_hook));\n    if (not return_recv_hook)\n        stream_wait(launch_stream, compute_stream);\n\n    // Allocate output tensor\n    torch::Tensor combined_x;\n    if (out.has_value()) {\n        EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());\n        EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);\n        EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());\n        combined_x = out.value();\n    } else {\n        combined_x = torch::empty({num_combined_tokens, hidden}, x.options());\n    }\n\n    // Kernel launch\n    auto next_clean_meta = next_buffer.clean_meta();\n    auto launcher = [=](int phases) {\n        internode_ll::combine(combined_x.data_ptr(),\n                              buffer.combine_rdma_recv_data_buffer,\n                              buffer.combine_rdma_recv_flag_buffer,\n                              buffer.combine_rdma_send_buffer,\n                              x.data_ptr(),\n                              topk_idx.data_ptr<topk_idx_t>(),\n                              topk_weights.data_ptr<float>(),\n                              src_info.data_ptr<int>(),\n                              layout_range.data_ptr<int64_t>(),\n                              mask_buffer_ptr,\n                              combine_wait_recv_cost_stats.has_value() ? 
combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,\n                              next_clean_meta.first,\n                              next_clean_meta.second,\n                              num_combined_tokens,\n                              hidden,\n                              num_max_dispatch_tokens_per_rank,\n                              num_topk,\n                              num_experts,\n                              rank,\n                              num_ranks,\n                              use_logfmt,\n                              workspace,\n                              num_device_sms,\n                              launch_stream,\n                              phases,\n                              zero_copy);\n    };\n    launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));\n\n    // Wait streams\n    std::optional<EventHandle> event;\n    if (async) {\n        // NOTES: we must ensure that no tensor is deallocated before the stream wait happens,\n        // so the Python API must wrap all tensors into the event handle.\n        event = EventHandle(launch_stream);\n    } else if (not return_recv_hook) {\n        stream_wait(compute_stream, launch_stream);\n    }\n\n    // Receiver callback\n    std::optional<std::function<void()>> recv_hook = std::nullopt;\n    if (return_recv_hook)\n        recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };\n\n    // Return values\n    return {combined_x, event, recv_hook};\n#else\n    EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n    return {};\n#endif\n}\n\ntorch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {\n#ifndef DISABLE_NVSHMEM\n    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);\n\n    auto buffer = layout.buffers[low_latency_buffer_idx];\n    auto dtype = 
torch::kBFloat16;\n    auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));\n\n    EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);\n    return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,\n                            {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},\n                            {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},\n                            torch::TensorOptions().dtype(dtype).device(torch::kCUDA));\n#else\n    EP_HOST_ASSERT(false and \"NVSHMEM is disabled during compilation\");\n    return {};\n#endif\n}\n\nbool is_sm90_compiled() {\n#ifndef DISABLE_SM90_FEATURES\n    return true;\n#else\n    return false;\n#endif\n}\n\nvoid Buffer::low_latency_update_mask_buffer(int rank_to_mask, bool mask) {\n    EP_HOST_ASSERT(mask_buffer_ptr != nullptr and \"Shrink mode must be enabled\");\n    EP_HOST_ASSERT(rank_to_mask >= 0 and rank_to_mask < num_ranks);\n    internode_ll::update_mask_buffer(mask_buffer_ptr, rank_to_mask, mask, at::cuda::getCurrentCUDAStream());\n}\n\nvoid Buffer::low_latency_query_mask_buffer(const torch::Tensor& mask_status) {\n    EP_HOST_ASSERT(mask_buffer_ptr != nullptr and \"Shrink mode must be enabled\");\n    EP_HOST_ASSERT(mask_status.numel() == num_ranks && mask_status.scalar_type() == torch::kInt32);\n\n    internode_ll::query_mask_buffer(\n        mask_buffer_ptr, num_ranks, reinterpret_cast<int*>(mask_status.data_ptr()), at::cuda::getCurrentCUDAStream());\n}\n\nvoid Buffer::low_latency_clean_mask_buffer() {\n    EP_HOST_ASSERT(mask_buffer_ptr != nullptr and \"Shrink mode must be enabled\");\n    internode_ll::clean_mask_buffer(mask_buffer_ptr, num_ranks, at::cuda::getCurrentCUDAStream());\n}\n\n}  // namespace deep_ep\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.doc() = \"DeepEP: an efficient expert-parallel communication library\";\n\n   
 pybind11::class_<deep_ep::Config>(m, \"Config\")\n        .def(pybind11::init<int, int, int, int, int>(),\n             py::arg(\"num_sms\") = 20,\n             py::arg(\"num_max_nvl_chunked_send_tokens\") = 6,\n             py::arg(\"num_max_nvl_chunked_recv_tokens\") = 256,\n             py::arg(\"num_max_rdma_chunked_send_tokens\") = 6,\n             py::arg(\"num_max_rdma_chunked_recv_tokens\") = 256)\n        .def(\"get_nvl_buffer_size_hint\", &deep_ep::Config::get_nvl_buffer_size_hint)\n        .def(\"get_rdma_buffer_size_hint\", &deep_ep::Config::get_rdma_buffer_size_hint);\n    m.def(\"get_low_latency_rdma_size_hint\", &deep_ep::get_low_latency_rdma_size_hint);\n\n    pybind11::class_<deep_ep::EventHandle>(m, \"EventHandle\")\n        .def(pybind11::init<>())\n        .def(\"current_stream_wait\", &deep_ep::EventHandle::current_stream_wait);\n\n    pybind11::class_<deep_ep::Buffer>(m, \"Buffer\")\n        .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool>())\n        .def(\"is_available\", &deep_ep::Buffer::is_available)\n        .def(\"get_num_rdma_ranks\", &deep_ep::Buffer::get_num_rdma_ranks)\n        .def(\"get_rdma_rank\", &deep_ep::Buffer::get_rdma_rank)\n        .def(\"get_root_rdma_rank\", &deep_ep::Buffer::get_root_rdma_rank)\n        .def(\"get_local_device_id\", &deep_ep::Buffer::get_local_device_id)\n        .def(\"get_local_ipc_handle\", &deep_ep::Buffer::get_local_ipc_handle)\n        .def(\"get_local_nvshmem_unique_id\", &deep_ep::Buffer::get_local_nvshmem_unique_id)\n        .def(\"get_local_buffer_tensor\", &deep_ep::Buffer::get_local_buffer_tensor)\n        .def(\"get_comm_stream\", &deep_ep::Buffer::get_comm_stream)\n        .def(\"sync\", &deep_ep::Buffer::sync)\n        .def(\"destroy\", &deep_ep::Buffer::destroy)\n        .def(\"get_dispatch_layout\", &deep_ep::Buffer::get_dispatch_layout)\n        .def(\"intranode_dispatch\", &deep_ep::Buffer::intranode_dispatch)\n        .def(\"intranode_combine\", 
&deep_ep::Buffer::intranode_combine)\n        .def(\"internode_dispatch\", &deep_ep::Buffer::internode_dispatch)\n        .def(\"internode_combine\", &deep_ep::Buffer::internode_combine)\n        .def(\"clean_low_latency_buffer\", &deep_ep::Buffer::clean_low_latency_buffer)\n        .def(\"low_latency_dispatch\", &deep_ep::Buffer::low_latency_dispatch)\n        .def(\"low_latency_combine\", &deep_ep::Buffer::low_latency_combine)\n        .def(\"low_latency_update_mask_buffer\", &deep_ep::Buffer::low_latency_update_mask_buffer)\n        .def(\"low_latency_query_mask_buffer\", &deep_ep::Buffer::low_latency_query_mask_buffer)\n        .def(\"low_latency_clean_mask_buffer\", &deep_ep::Buffer::low_latency_clean_mask_buffer)\n        .def(\"get_next_low_latency_combine_buffer\", &deep_ep::Buffer::get_next_low_latency_combine_buffer);\n\n    m.def(\"is_sm90_compiled\", deep_ep::is_sm90_compiled);\n    m.attr(\"topk_idx_t\") =\n        py::reinterpret_borrow<py::object>((PyObject*)torch::getTHPDtype(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value));\n}\n"
  },
  {
    "path": "csrc/deep_ep.hpp",
    "content": "#pragma once\n\n// Forcibly disable NDEBUG\n#ifdef NDEBUG\n#undef NDEBUG\n#endif\n\n#include <pybind11/pybind11.h>\n#include <pybind11/pytypes.h>\n#include <torch/types.h>\n\n#include <tuple>\n#include <vector>\n\n#include \"config.hpp\"\n#include \"event.hpp\"\n#include \"kernels/configs.cuh\"\n#include \"kernels/exception.cuh\"\n\n#ifndef TORCH_EXTENSION_NAME\n#define TORCH_EXTENSION_NAME deep_ep_cpp\n#endif\n\nnamespace shared_memory {\n\nunion MemHandleInner {\n    cudaIpcMemHandle_t cuda_ipc_mem_handle;\n    CUmemFabricHandle cu_mem_fabric_handle;\n};\n\nstruct MemHandle {\n    MemHandleInner inner;\n    size_t size;\n};\n\nconstexpr size_t HANDLE_SIZE = sizeof(MemHandle);\n\nclass SharedMemoryAllocator {\npublic:\n    SharedMemoryAllocator(bool use_fabric);\n    void malloc(void** ptr, size_t size);\n    void free(void* ptr);\n    void get_mem_handle(MemHandle* mem_handle, void* ptr);\n    void open_mem_handle(void** ptr, MemHandle* mem_handle);\n    void close_mem_handle(void* ptr);\n\nprivate:\n    bool use_fabric;\n};\n}  // namespace shared_memory\n\nnamespace deep_ep {\n\nstruct Buffer {\n    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, \"The number of maximum NVLink peers must be 8\");\n\nprivate:\n    // Low-latency mode buffer\n    int low_latency_buffer_idx = 0;\n    bool low_latency_mode = false;\n\n    // NVLink Buffer\n    int64_t num_nvl_bytes;\n    void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};\n    void** buffer_ptrs_gpu = nullptr;\n\n    // NVSHMEM Buffer\n    int64_t num_rdma_bytes;\n    void* rdma_buffer_ptr = nullptr;\n\n    // Shrink mode buffer\n    bool enable_shrink = false;\n    int* mask_buffer_ptr = nullptr;\n    int* sync_buffer_ptr = nullptr;\n\n    // Device info and communication\n    int device_id;\n    int num_device_sms;\n    int rank, rdma_rank, nvl_rank;\n    int num_ranks, num_rdma_ranks, num_nvl_ranks;\n    shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];\n\n    // Stream for communication\n    
at::cuda::CUDAStream comm_stream;\n\n    // After IPC/NVSHMEM synchronization, this flag will be true\n    bool available = false;\n\n    // Whether explicit `destroy()` is required.\n    bool explicitly_destroy;\n    // After `destroy()` is called, this flag will be true\n    bool destroyed = false;\n\n    // Barrier signals\n    int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};\n    int** barrier_signal_ptrs_gpu = nullptr;\n\n    // Workspace\n    void* workspace = nullptr;\n\n    // Host-side MoE info\n    volatile int* moe_recv_counter = nullptr;\n    int* moe_recv_counter_mapped = nullptr;\n\n    // Host-side expert-level MoE info\n    volatile int* moe_recv_expert_counter = nullptr;\n    int* moe_recv_expert_counter_mapped = nullptr;\n\n    // Host-side RDMA-level MoE info\n    volatile int* moe_recv_rdma_counter = nullptr;\n    int* moe_recv_rdma_counter_mapped = nullptr;\n\n    shared_memory::SharedMemoryAllocator shared_memory_allocator;\n\npublic:\n    Buffer(int rank,\n           int num_ranks,\n           int64_t num_nvl_bytes,\n           int64_t num_rdma_bytes,\n           bool low_latency_mode,\n           bool explicitly_destroy,\n           bool enable_shrink,\n           bool use_fabric);\n\n    ~Buffer() noexcept(false);\n\n    bool is_available() const;\n\n    bool is_internode_available() const;\n\n    int get_num_rdma_ranks() const;\n\n    int get_rdma_rank() const;\n\n    int get_root_rdma_rank(bool global) const;\n\n    int get_local_device_id() const;\n\n    pybind11::bytearray get_local_ipc_handle() const;\n\n    pybind11::bytearray get_local_nvshmem_unique_id() const;\n\n    torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;\n\n    torch::Stream get_comm_stream() const;\n\n    void sync(const std::vector<int>& device_ids,\n              const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles,\n              const std::optional<pybind11::bytearray>& 
root_unique_id_opt);\n\n    void destroy();\n\n    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>> get_dispatch_layout(\n        const torch::Tensor& topk_idx,\n        int num_experts,\n        std::optional<EventHandle>& previous_event,\n        bool async,\n        bool allocate_on_comm_stream);\n\n    std::tuple<torch::Tensor,\n               std::optional<torch::Tensor>,\n               std::optional<torch::Tensor>,\n               std::optional<torch::Tensor>,\n               std::vector<int>,\n               torch::Tensor,\n               torch::Tensor,\n               torch::Tensor,\n               torch::Tensor,\n               torch::Tensor,\n               std::optional<EventHandle>>\n    intranode_dispatch(const torch::Tensor& x,\n                       const std::optional<torch::Tensor>& x_scales,\n                       const std::optional<torch::Tensor>& topk_idx,\n                       const std::optional<torch::Tensor>& topk_weights,\n                       const std::optional<torch::Tensor>& num_tokens_per_rank,\n                       const torch::Tensor& is_token_in_rank,\n                       const std::optional<torch::Tensor>& num_tokens_per_expert,\n                       int cached_num_recv_tokens,\n                       const std::optional<torch::Tensor>& cached_rank_prefix_matrix,\n                       const std::optional<torch::Tensor>& cached_channel_prefix_matrix,\n                       int expert_alignment,\n                       int num_worst_tokens,\n                       const Config& config,\n                       std::optional<EventHandle>& previous_event,\n                       bool async,\n                       bool allocate_on_comm_stream);\n\n    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> intranode_combine(\n        const torch::Tensor& x,\n        const std::optional<torch::Tensor>& topk_weights,\n        
const std::optional<torch::Tensor>& bias_0,\n        const std::optional<torch::Tensor>& bias_1,\n        const torch::Tensor& src_idx,\n        const torch::Tensor& rank_prefix_matrix,\n        const torch::Tensor& channel_prefix_matrix,\n        const torch::Tensor& send_head,\n        const Config& config,\n        std::optional<EventHandle>& previous_event,\n        bool async,\n        bool allocate_on_comm_stream);\n\n    std::tuple<torch::Tensor,\n               std::optional<torch::Tensor>,\n               std::optional<torch::Tensor>,\n               std::optional<torch::Tensor>,\n               std::vector<int>,\n               torch::Tensor,\n               torch::Tensor,\n               std::optional<torch::Tensor>,\n               torch::Tensor,\n               std::optional<torch::Tensor>,\n               torch::Tensor,\n               std::optional<torch::Tensor>,\n               std::optional<torch::Tensor>,\n               std::optional<torch::Tensor>,\n               std::optional<EventHandle>>\n    internode_dispatch(const torch::Tensor& x,\n                       const std::optional<torch::Tensor>& x_scales,\n                       const std::optional<torch::Tensor>& topk_idx,\n                       const std::optional<torch::Tensor>& topk_weights,\n                       const std::optional<torch::Tensor>& num_tokens_per_rank,\n                       const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,\n                       const torch::Tensor& is_token_in_rank,\n                       const std::optional<torch::Tensor>& num_tokens_per_expert,\n                       int cached_num_recv_tokens,\n                       int cached_num_rdma_recv_tokens,\n                       const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix,\n                       const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,\n                       const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,\n    
                   const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,\n                       int expert_alignment,\n                       int num_worst_tokens,\n                       const Config& config,\n                       std::optional<EventHandle>& previous_event,\n                       bool async,\n                       bool allocate_on_comm_stream);\n\n    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> internode_combine(\n        const torch::Tensor& x,\n        const std::optional<torch::Tensor>& topk_weights,\n        const std::optional<torch::Tensor>& bias_0,\n        const std::optional<torch::Tensor>& bias_1,\n        const torch::Tensor& src_meta,\n        const torch::Tensor& is_combined_token_in_rank,\n        const torch::Tensor& rdma_channel_prefix_matrix,\n        const torch::Tensor& rdma_rank_prefix_sum,\n        const torch::Tensor& gbl_channel_prefix_matrix,\n        const torch::Tensor& combined_rdma_head,\n        const torch::Tensor& combined_nvl_head,\n        const Config& config,\n        std::optional<EventHandle>& previous_event,\n        bool async,\n        bool allocate_on_comm_stream);\n\n    void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);\n\n    std::tuple<torch::Tensor,\n               std::optional<torch::Tensor>,\n               torch::Tensor,\n               torch::Tensor,\n               torch::Tensor,\n               std::optional<EventHandle>,\n               std::optional<std::function<void()>>>\n    low_latency_dispatch(const torch::Tensor& x,\n                         const torch::Tensor& topk_idx,\n                         const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,\n                         const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,\n                         int num_max_dispatch_tokens_per_rank,\n                         int num_experts,\n                  
       bool use_fp8,\n                         bool round_scale,\n                         bool use_ue8m0,\n                         bool async,\n                         bool return_recv_hook);\n\n    std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> low_latency_combine(\n        const torch::Tensor& x,\n        const torch::Tensor& topk_idx,\n        const torch::Tensor& topk_weights,\n        const torch::Tensor& src_info,\n        const torch::Tensor& layout_range,\n        const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,\n        int num_max_dispatch_tokens_per_rank,\n        int num_experts,\n        bool use_logfmt,\n        bool zero_copy,\n        bool async,\n        bool return_recv_hook,\n        const std::optional<torch::Tensor>& out = std::nullopt);\n\n    torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;\n\n    void low_latency_update_mask_buffer(int rank_to_mask, bool mask);\n\n    void low_latency_query_mask_buffer(const torch::Tensor& mask_status);\n\n    void low_latency_clean_mask_buffer();\n};\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/event.hpp",
    "content": "#pragma once\n\n#include <ATen/cuda/CUDAContext.h>\n\n#include <memory>\n\n#include \"kernels/exception.cuh\"\n\nnamespace deep_ep {\n\nstruct EventHandle {\n    std::shared_ptr<torch::Event> event;\n\n    EventHandle() {\n        event = std::make_shared<torch::Event>(torch::kCUDA);\n        event->record(at::cuda::getCurrentCUDAStream());\n    }\n\n    explicit EventHandle(const at::cuda::CUDAStream& stream) {\n        event = std::make_shared<torch::Event>(torch::kCUDA);\n        event->record(stream);\n    }\n\n    EventHandle(const EventHandle& other) = default;\n\n    void current_stream_wait() const { at::cuda::getCurrentCUDAStream().unwrap().wait(*event); }\n};\n\n// Free functions defined in a header must be `inline` to avoid multiple-definition link errors\ninline torch::Event create_event(const at::cuda::CUDAStream& s) {\n    auto event = torch::Event(torch::kCUDA);\n    event.record(s);\n    return event;\n}\n\n// Make `s_0` wait for all work currently enqueued on `s_1`\ninline void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {\n    EP_HOST_ASSERT(s_0.id() != s_1.id());\n    s_0.unwrap().wait(create_event(s_1));\n}\n\n// Make `s` wait for the event recorded in `event`\ninline void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {\n    s.unwrap().wait(*event.event);\n}\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/CMakeLists.txt",
    "content": "function(add_deep_ep_library target_name source_file)\n    add_library(${target_name} STATIC ${source_file})\n    set_target_properties(${target_name} PROPERTIES\n            POSITION_INDEPENDENT_CODE ON\n            CXX_STANDARD_REQUIRED ON\n            CUDA_STANDARD_REQUIRED ON\n            CXX_STANDARD 17\n            CUDA_STANDARD 17\n            CUDA_SEPARABLE_COMPILATION ON\n    )\n    target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)\nendfunction()\n\nadd_deep_ep_library(runtime_cuda runtime.cu)\nadd_deep_ep_library(layout_cuda layout.cu)\nadd_deep_ep_library(intranode_cuda intranode.cu)\nadd_deep_ep_library(internode_cuda internode.cu)\nadd_deep_ep_library(internode_ll_cuda internode_ll.cu)\n\n# Later, we should link all libraries in `EP_CUDA_LIBRARIES`\nset(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)\n"
  },
  {
    "path": "csrc/kernels/api.cuh",
    "content": "#pragma once\n\n#include <vector>\n\n#include \"configs.cuh\"\n\nnamespace deep_ep {\n\n// Intranode runtime\nnamespace intranode {\n\nvoid barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);\n\n}  // namespace intranode\n\n// Internode runtime\nnamespace internode {\n\nstd::vector<uint8_t> get_unique_id();\n\nint init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);\n\nvoid* alloc(size_t size, size_t alignment);\n\nvoid free(void* ptr);\n\nvoid barrier();\n\nvoid finalize();\n\n}  // namespace internode\n\n// Layout kernels\nnamespace layout {\n\nvoid get_dispatch_layout(const topk_idx_t* topk_idx,\n                         int* num_tokens_per_rank,\n                         int* num_tokens_per_rdma_rank,\n                         int* num_tokens_per_expert,\n                         bool* is_token_in_rank,\n                         int num_tokens,\n                         int num_topk,\n                         int num_ranks,\n                         int num_experts,\n                         cudaStream_t stream);\n\n}  // namespace layout\n\n// Intranode kernels\nnamespace intranode {\n\nvoid notify_dispatch(const int* num_tokens_per_rank,\n                     int* moe_recv_counter_mapped,\n                     int num_ranks,\n                     const int* num_tokens_per_expert,\n                     int* moe_recv_expert_counter_mapped,\n                     int num_experts,\n                     int num_tokens,\n                     const bool* is_token_in_rank,\n                     int* channel_prefix_matrix,\n                     int* rank_prefix_matrix_copy,\n                     int num_memset_int,\n                     int expert_alignment,\n                     void** buffer_ptrs,\n                     int** barrier_signal_ptrs,\n                     int rank,\n                     cudaStream_t stream,\n                     int num_sms);\n\nvoid 
cached_notify_dispatch(const int* rank_prefix_matrix,\n                            int num_memset_int,\n                            void** buffer_ptrs,\n                            int** barrier_signal_ptrs,\n                            int rank,\n                            int num_ranks,\n                            cudaStream_t stream);\n\nvoid dispatch(void* recv_x,\n              float* recv_x_scales,\n              int* recv_src_idx,\n              topk_idx_t* recv_topk_idx,\n              float* recv_topk_weights,\n              int* recv_channel_offset,\n              int* send_head,\n              const void* x,\n              const float* x_scales,\n              const topk_idx_t* topk_idx,\n              const float* topk_weights,\n              const bool* is_token_in_rank,\n              const int* channel_prefix_matrix,\n              int num_tokens,\n              int num_worst_tokens,\n              int hidden_int4,\n              int num_topk,\n              int num_experts,\n              int num_scales,\n              int scale_token_stride,\n              int scale_hidden_stride,\n              void** buffer_ptrs,\n              int rank,\n              int num_ranks,\n              cudaStream_t stream,\n              int num_sms,\n              int num_max_send_tokens,\n              int num_recv_buffer_tokens);\n\nvoid cached_notify_combine(void** buffer_ptrs,\n                           int* send_head,\n                           int num_channels,\n                           int num_recv_tokens,\n                           int num_memset_int,\n                           int** barrier_signal_ptrs,\n                           int rank,\n                           int num_ranks,\n                           cudaStream_t stream);\n\nvoid combine(cudaDataType_t type,\n             void* recv_x,\n             float* recv_topk_weights,\n             const void* x,\n             const float* topk_weights,\n             const void* bias_0,\n            
 const void* bias_1,\n             const int* src_idx,\n             const int* rank_prefix_matrix,\n             const int* channel_prefix_matrix,\n             int* send_head,\n             int num_tokens,\n             int num_recv_tokens,\n             int hidden,\n             int num_topk,\n             void** buffer_ptrs,\n             int rank,\n             int num_ranks,\n             cudaStream_t stream,\n             int num_sms,\n             int num_max_send_tokens,\n             int num_recv_buffer_tokens);\n\n}  // namespace intranode\n\n// Internode kernels\nnamespace internode {\n\nint get_source_meta_bytes();\n\nvoid notify_dispatch(const int* num_tokens_per_rank,\n                     int* moe_recv_counter_mapped,\n                     int num_ranks,\n                     const int* num_tokens_per_rdma_rank,\n                     int* moe_recv_rdma_counter_mapped,\n                     const int* num_tokens_per_expert,\n                     int* moe_recv_expert_counter_mapped,\n                     int num_experts,\n                     const bool* is_token_in_rank,\n                     int num_tokens,\n                     int num_worst_tokens,\n                     int num_channels,\n                     int hidden_int4,\n                     int num_scales,\n                     int num_topk,\n                     int expert_alignment,\n                     int* rdma_channel_prefix_matrix,\n                     int* recv_rdma_rank_prefix_sum,\n                     int* gbl_channel_prefix_matrix,\n                     int* recv_gbl_rank_prefix_sum,\n                     void* rdma_buffer_ptr,\n                     int num_max_rdma_chunked_recv_tokens,\n                     void** buffer_ptrs,\n                     int num_max_nvl_chunked_recv_tokens,\n                     int** barrier_signal_ptrs,\n                     int rank,\n                     cudaStream_t stream,\n                     int64_t num_rdma_bytes,\n                     
int64_t num_nvl_bytes,\n                     bool low_latency_mode);\n\nvoid dispatch(void* recv_x,\n              float* recv_x_scales,\n              topk_idx_t* recv_topk_idx,\n              float* recv_topk_weights,\n              void* recv_src_meta,\n              const void* x,\n              const float* x_scales,\n              const topk_idx_t* topk_idx,\n              const float* topk_weights,\n              int* send_rdma_head,\n              int* send_nvl_head,\n              int* recv_rdma_channel_prefix_matrix,\n              int* recv_gbl_channel_prefix_matrix,\n              const int* rdma_channel_prefix_matrix,\n              const int* recv_rdma_rank_prefix_sum,\n              const int* gbl_channel_prefix_matrix,\n              const int* recv_gbl_rank_prefix_sum,\n              const bool* is_token_in_rank,\n              int num_tokens,\n              int num_worst_tokens,\n              int hidden_int4,\n              int num_scales,\n              int num_topk,\n              int num_experts,\n              int scale_token_stride,\n              int scale_hidden_stride,\n              void* rdma_buffer_ptr,\n              int num_max_rdma_chunked_send_tokens,\n              int num_max_rdma_chunked_recv_tokens,\n              void** buffer_ptrs,\n              int num_max_nvl_chunked_send_tokens,\n              int num_max_nvl_chunked_recv_tokens,\n              int rank,\n              int num_ranks,\n              bool is_cached_dispatch,\n              cudaStream_t stream,\n              int num_channels,\n              bool low_latency_mode);\n\nvoid cached_notify(int hidden_int4,\n                   int num_scales,\n                   int num_topk_idx,\n                   int num_topk_weights,\n                   int num_ranks,\n                   int num_channels,\n                   int num_combined_tokens,\n                   int* combined_rdma_head,\n                   const int* rdma_channel_prefix_matrix,\n                   
const int* rdma_rank_prefix_sum,\n                   int* combined_nvl_head,\n                   void* rdma_buffer_ptr,\n                   int num_max_rdma_chunked_recv_tokens,\n                   void** buffer_ptrs,\n                   int num_max_nvl_chunked_recv_tokens,\n                   int** barrier_signal_ptrs,\n                   int rank,\n                   cudaStream_t stream,\n                   int64_t num_rdma_bytes,\n                   int64_t num_nvl_bytes,\n                   bool is_cached_dispatch,\n                   bool low_latency_mode);\n\nvoid combine(cudaDataType_t type,\n             void* combined_x,\n             float* combined_topk_weights,\n             const bool* is_combined_token_in_rank,\n             const void* x,\n             const float* topk_weights,\n             const void* bias_0,\n             const void* bias_1,\n             const int* combined_rdma_head,\n             const int* combined_nvl_head,\n             const void* src_meta,\n             const int* rdma_channel_prefix_matrix,\n             const int* rdma_rank_prefix_sum,\n             const int* gbl_channel_prefix_matrix,\n             int num_tokens,\n             int num_combined_tokens,\n             int hidden,\n             int num_topk,\n             void* rdma_buffer_ptr,\n             int num_max_rdma_chunked_send_tokens,\n             int num_max_rdma_chunked_recv_tokens,\n             void** buffer_ptrs,\n             int num_max_nvl_chunked_send_tokens,\n             int num_max_nvl_chunked_recv_tokens,\n             int rank,\n             int num_ranks,\n             cudaStream_t stream,\n             int num_channels,\n             bool low_latency_mode);\n\n}  // namespace internode\n\n// Internode low-latency kernels\nnamespace internode_ll {\n\nvoid clean_low_latency_buffer(int* clean_0,\n                              int num_clean_int_0,\n                              int* clean_1,\n                              int num_clean_int_1,\n    
                          int rank,\n                              int num_ranks,\n                              int* mask_buffer,\n                              int* sync_buffer,\n                              cudaStream_t stream);\n\nvoid dispatch(void* packed_recv_x,\n              void* packed_recv_x_scales,\n              int* packed_recv_src_info,\n              int64_t* packed_recv_layout_range,\n              int* packed_recv_count,\n              int* mask_buffer,\n              int* cumulative_local_expert_recv_stats,\n              int64_t* dispatch_wait_recv_cost_stats,\n              void* rdma_recv_x,\n              int* rdma_recv_count,\n              void* rdma_x,\n              const void* x,\n              const topk_idx_t* topk_idx,\n              int* next_clean,\n              int num_next_clean_int,\n              int num_tokens,\n              int hidden,\n              int num_max_dispatch_tokens_per_rank,\n              int num_topk,\n              int num_experts,\n              int rank,\n              int num_ranks,\n              bool use_fp8,\n              bool round_scale,\n              bool use_ue8m0,\n              void* workspace,\n              int num_device_sms,\n              cudaStream_t stream,\n              int phases);\n\nvoid combine(void* combined_x,\n             void* rdma_recv_x,\n             int* rdma_recv_flag,\n             void* rdma_send_x,\n             const void* x,\n             const topk_idx_t* topk_idx,\n             const float* topk_weights,\n             const int* src_info,\n             const int64_t* layout_range,\n             int* mask_buffer,\n             int64_t* combine_wait_recv_cost_stats,\n             int* next_clean,\n             int num_next_clean_int,\n             int num_combined_tokens,\n             int hidden,\n             int num_max_dispatch_tokens_per_rank,\n             int num_topk,\n             int num_experts,\n             int rank,\n             int num_ranks,\n       
      bool use_logfmt,\n             void* workspace,\n             int num_device_sms,\n             cudaStream_t stream,\n             int phases,\n             bool zero_copy);\n\nvoid query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream);\n\nvoid update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, cudaStream_t stream);\n\nvoid clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream);\n\n}  // namespace internode_ll\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/buffer.cuh",
    "content": "#pragma once\n\n#include \"configs.cuh\"\n#include \"exception.cuh\"\n\nnamespace deep_ep {\n\ntemplate <typename dtype_t>\nstruct Buffer {\nprivate:\n    uint8_t* ptr;\n\npublic:\n    int64_t total_bytes;\n\n    __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}\n\n    __device__ __forceinline__ Buffer(void*& gbl_ptr, int num_elems, int offset = 0) {\n        total_bytes = num_elems * sizeof(dtype_t);\n        ptr = static_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);\n        gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;\n    }\n\n    __device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) {\n        gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;\n        return *this;\n    }\n\n    __device__ __forceinline__ dtype_t* buffer() { return reinterpret_cast<dtype_t*>(ptr); }\n\n    __device__ __forceinline__ dtype_t& operator[](int idx) { return buffer()[idx]; }\n};\n\ntemplate <typename dtype_t, int kNumRanks = 1>\nstruct AsymBuffer {\nprivate:\n    uint8_t* ptrs[kNumRanks];\n    int64_t num_bytes;\n\npublic:\n    int64_t total_bytes;\n\n    __device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) {\n        EP_STATIC_ASSERT(kNumRanks == 1, \"\");\n        num_bytes = num_elems * sizeof(dtype_t);\n\n        int64_t per_channel_bytes = num_bytes * num_ranks;\n        total_bytes = per_channel_bytes * num_sms;\n        ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;\n        gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;\n    }\n\n    __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) {\n        EP_STATIC_ASSERT(kNumRanks > 1, \"\");\n        num_bytes = num_elems * sizeof(dtype_t);\n\n        int64_t per_channel_bytes = num_bytes * num_ranks;\n        total_bytes = per_channel_bytes * num_sms;\n   
     for (int i = 0; i < kNumRanks; ++i) {\n            ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;\n            gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;\n        }\n    }\n\n    __device__ __forceinline__ void advance(int shift) {\n        #pragma unroll\n        for (int i = 0; i < kNumRanks; ++i)\n            ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);\n    }\n\n    __device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) {\n        gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;\n        return *this;\n    }\n\n    template <int kNumAlsoRanks>\n    __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {\n        for (int i = 0; i < kNumAlsoRanks; ++i)\n            gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;\n        return *this;\n    }\n\n    __device__ __forceinline__ dtype_t* buffer(int idx = 0) {\n        EP_STATIC_ASSERT(kNumRanks == 1, \"`buffer` is only available for single rank case\");\n        return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);\n    }\n\n    __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {\n        EP_STATIC_ASSERT(kNumRanks > 1, \"`buffer_by` is only available for multiple rank case\");\n        return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);\n    }\n};\n\ntemplate <typename dtype_t, bool kDecoupled = true>\nstruct SymBuffer {\nprivate:\n    // NOTES: for non-decoupled case, `recv_ptr` is not used\n    uint8_t* send_ptr;\n    uint8_t* recv_ptr;\n    int64_t num_bytes;\n\npublic:\n    int64_t total_bytes;\n\n    __device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) {\n        num_bytes = num_elems * sizeof(dtype_t);\n\n        int64_t per_channel_bytes = num_bytes * num_ranks;\n        total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);\n        send_ptr = 
static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;\n        recv_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);\n        gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;\n    }\n\n    __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {\n        EP_STATIC_ASSERT(kDecoupled, \"`send_buffer` is only available for decoupled case\");\n        return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);\n    }\n\n    __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {\n        EP_STATIC_ASSERT(kDecoupled, \"`recv_buffer` is only available for decoupled case\");\n        return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);\n    }\n\n    __device__ __forceinline__ dtype_t* buffer(int idx = 0) {\n        EP_STATIC_ASSERT(not kDecoupled, \"`buffer` is only available for non-decoupled case\");\n        return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);\n    }\n};\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/configs.cuh",
    "content": "#pragma once\n\n#define NUM_MAX_NVL_PEERS 8\n#define NUM_MAX_RDMA_PEERS 20\n#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)\n#define NUM_MAX_LOCAL_EXPERTS 1024\n#define NUM_BUFFER_ALIGNMENT_BYTES 128\n\n#define FINISHED_SUM_TAG 1024\n#define NUM_WAIT_NANOSECONDS 500\n\n#ifndef ENABLE_FAST_DEBUG\n#define NUM_CPU_TIMEOUT_SECS 100\n#define NUM_TIMEOUT_CYCLES 200000000000ull  // 200G cycles ~= 100s\n#else\n#define NUM_CPU_TIMEOUT_SECS 10\n#define NUM_TIMEOUT_CYCLES 20000000000ull  // 20G cycles ~= 10s\n#endif\n\n#define LOW_LATENCY_SEND_PHASE 1\n#define LOW_LATENCY_RECV_PHASE 2\n\n// Make CLion CUDA indexing work\n#ifdef __CLION_IDE__\n#define __CUDA_ARCH__ 900  // NOLINT(*-reserved-identifier)\n#define __CUDACC_RDC__     // NOLINT(*-reserved-identifier)\n#endif\n\n// Define __CUDACC_RDC__ to ensure proper extern declarations for NVSHMEM device symbols\n#ifndef DISABLE_NVSHMEM\n#ifndef __CUDACC_RDC__\n#define __CUDACC_RDC__  // NOLINT(*-reserved-identifier)\n#endif\n#endif\n\n// Remove Torch restrictions\n#ifdef __CUDA_NO_HALF_CONVERSIONS__\n#undef __CUDA_NO_HALF_CONVERSIONS__\n#endif\n#ifdef __CUDA_NO_HALF_OPERATORS__\n#undef __CUDA_NO_HALF_OPERATORS__\n#endif\n#ifdef __CUDA_NO_HALF2_OPERATORS__\n#undef __CUDA_NO_HALF2_OPERATORS__\n#endif\n#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__\n#undef __CUDA_NO_BFLOAT16_CONVERSIONS__\n#endif\n#ifdef __CUDA_NO_BFLOAT162_OPERATORS__\n#undef __CUDA_NO_BFLOAT162_OPERATORS__\n#endif\n\n#include <cuda_bf16.h>\n#include <cuda_runtime.h>\n\n#include <cstdint>\n\n#ifndef DISABLE_SM90_FEATURES\n#include <cuda_fp8.h>\n#else\n// Ampere does not support FP8 features\n#define __NV_E4M3 0\n#define __NV_E5M2 1\ntypedef int __nv_fp8_interpretation_t;\ntypedef int __nv_fp8x4_e4m3;\ntypedef uint8_t __nv_fp8_storage_t;\n#endif\n\nnamespace deep_ep {\n\n#ifndef TOPK_IDX_BITS\n#define TOPK_IDX_BITS 64\n#endif\n\n#define INT_BITS_T2(bits) int##bits##_t\n#define INT_BITS_T(bits) INT_BITS_T2(bits)\ntypedef INT_BITS_T(TOPK_IDX_BITS) 
topk_idx_t;  // int32_t or int64_t\n#undef INT_BITS_T\n#undef INT_BITS_T2\n\n}  // namespace deep_ep\n\n#ifndef DISABLE_NVSHMEM\n#include <device_host_transport/nvshmem_common_ibgda.h>\n#include <infiniband/mlx5dv.h>\n#include <nvshmem.h>\n#include <nvshmemx.h>\n\n#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>\n#endif\n"
  },
  {
    "path": "csrc/kernels/exception.cuh",
    "content": "#pragma once\n\n#include <exception>\n#include <string>\n\n#include \"configs.cuh\"\n\n#ifndef EP_STATIC_ASSERT\n#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)\n#endif\n\nclass EPException : public std::exception {\nprivate:\n    std::string message = {};\n\npublic:\n    explicit EPException(const char* name, const char* file, const int line, const std::string& error) {\n        message = std::string(\"Failed: \") + name + \" error \" + file + \":\" + std::to_string(line) + \" '\" + error + \"'\";\n    }\n\n    const char* what() const noexcept override { return message.c_str(); }\n};\n\n#ifndef CUDA_CHECK\n#define CUDA_CHECK(cmd)                                                           \\\n    do {                                                                          \\\n        cudaError_t e = (cmd);                                                    \\\n        if (e != cudaSuccess) {                                                   \\\n            throw EPException(\"CUDA\", __FILE__, __LINE__, cudaGetErrorString(e)); \\\n        }                                                                         \\\n    } while (0)\n#endif\n\n#ifndef CU_CHECK\n#define CU_CHECK(cmd)                                                            \\\n    do {                                                                         \\\n        CUresult e = (cmd);                                                      \\\n        if (e != CUDA_SUCCESS) {                                                 \\\n            const char* error_str = NULL;                                        \\\n            cuGetErrorString(e, &error_str);                                     \\\n            throw EPException(\"CU\", __FILE__, __LINE__, std::string(error_str)); \\\n        }                                                                        \\\n    } while (0)\n#endif\n\n#ifndef EP_HOST_ASSERT\n#define EP_HOST_ASSERT(cond)                         
                  \\\n    do {                                                               \\\n        if (not(cond)) {                                               \\\n            throw EPException(\"Assertion\", __FILE__, __LINE__, #cond); \\\n        }                                                              \\\n    } while (0)\n#endif\n\n#ifndef EP_DEVICE_ASSERT\n#define EP_DEVICE_ASSERT(cond)                                                             \\\n    do {                                                                                   \\\n        if (not(cond)) {                                                                   \\\n            printf(\"Assertion failed: %s:%d, condition: %s\\n\", __FILE__, __LINE__, #cond); \\\n            asm(\"trap;\");                                                                  \\\n        }                                                                                  \\\n    } while (0)\n#endif\n"
  },
  {
    "path": "csrc/kernels/ibgda_device.cuh",
    "content": "// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem)\n// Copyright (c) NVIDIA Corporation.\n// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019).\n// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html\n//\n// Modified from original source:\n//  - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh\n#pragma once\n\n#include <type_traits>\n\n#include \"configs.cuh\"\n#include \"exception.cuh\"\n#include \"utils.cuh\"\n\nnamespace deep_ep {\n\nEP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, \"Invalid QP minimum depth\");\n\n__device__ static __forceinline__ uint64_t HtoBE64(uint64_t x) {\n    uint64_t ret;\n    asm(\"{\\n\\t\"\n        \".reg .b32 ign;\\n\\t\"\n        \".reg .b32 lo;\\n\\t\"\n        \".reg .b32 hi;\\n\\t\"\n        \".reg .b32 new_lo;\\n\\t\"\n        \".reg .b32 new_hi;\\n\\t\"\n        \"mov.b64 {lo,hi}, %1;\\n\\t\"\n        \"prmt.b32 new_hi, lo, ign, 0x0123;\\n\\t\"\n        \"prmt.b32 new_lo, hi, ign, 0x0123;\\n\\t\"\n        \"mov.b64 %0, {new_lo,new_hi};\\n\\t\"\n        \"}\"\n        : \"=l\"(ret)\n        : \"l\"(x));\n    return ret;\n}\n\n__device__ static __forceinline__ uint32_t HtoBE32(uint32_t x) {\n    uint32_t ret;\n    asm(\"{\\n\\t\"\n        \".reg .b32 ign;\\n\\t\"\n        \"prmt.b32 %0, %1, ign, 0x0123;\\n\\t\"\n        \"}\"\n        : \"=r\"(ret)\n        : \"r\"(x));\n    return ret;\n}\n\n__device__ static __forceinline__ uint16_t HtoBE16(uint16_t x) {\n    // TODO: simplify PTX using 16-bit instructions\n    auto a = static_cast<uint32_t>(x);\n    uint32_t d;\n    asm volatile(\n        \"{\\n\\t\"\n        \".reg .b32 mask;\\n\\t\"\n        \".reg .b32 ign;\\n\\t\"\n        \"mov.b32 mask, 0x4401;\\n\\t\"\n        \"mov.b32 ign, 0x0;\\n\\t\"\n        \"prmt.b32 %0, %1, ign, mask;\\n\\t\"\n        \"}\"\n        : \"=r\"(d)\n        : \"r\"(a));\n    return static_cast<uint16_t>(d);\n}\n\ntypedef struct mlx5_wqe_ctrl_seg 
__attribute__((__aligned__(8))) ibgda_ctrl_seg_t;\n\ntypedef struct {\n    uint32_t add_data;\n    uint32_t field_boundary;\n    uint64_t reserved;\n} __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t;\n\n__device__ static __forceinline__ nvshmemi_ibgda_device_state_t* ibgda_get_state() {\n    return &nvshmemi_ibgda_device_state_d;\n}\n\n// Template helper to get RC - uses compile-time type checking with if constexpr (C++17)\ntemplate <typename StateType>\n__device__ static __forceinline__ nvshmemi_ibgda_device_qp_t* ibgda_get_rc_impl(StateType* state, int pe, int id) {\n    const auto num_rc_per_pe = state->num_rc_per_pe;\n\n    if constexpr (std::is_same_v<StateType, nvshmemi_ibgda_device_state_v1>) {\n        // v1 implementation\n        return &state->globalmem\n                    .rcs[pe * num_rc_per_pe * state->num_devices_initialized + id % (num_rc_per_pe * state->num_devices_initialized)];\n    } else {\n        // v2 implementation (or any other type)\n        return &state->globalmem.rcs[pe + nvshmemi_device_state_d.npes * id];\n    }\n}\n\n__device__ static __forceinline__ nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {\n    auto state = ibgda_get_state();\n    return ibgda_get_rc_impl(state, pe, id);\n}\n\n__device__ static __forceinline__ void ibgda_lock_acquire(int* lock) {\n    while (atomicCAS(lock, 0, 1) == 1)\n        ;\n\n    // Prevent reordering before the lock is acquired\n    memory_fence_cta();\n}\n\n__device__ static __forceinline__ void ibgda_lock_release(int* lock) {\n    memory_fence_cta();\n\n    // Prevent reordering before lock is released\n    st_na_relaxed(lock, 0);\n}\n\n__device__ static __forceinline__ void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t* qp, uint32_t dbrec_head) {\n    // `DBREC` contains the index of the next empty `WQEBB`\n    __be32 dbrec_val;\n    __be32* dbrec_ptr = qp->tx_wq.dbrec;\n\n    // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))`\n    
asm(\"{\\n\\t\"\n        \".reg .b32 dbrec_head_16b;\\n\\t\"\n        \".reg .b32 ign;\\n\\t\"\n        \"and.b32 dbrec_head_16b, %1, 0xffff;\\n\\t\"\n        \"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\\n\\t\"\n        \"}\"\n        : \"=r\"(dbrec_val)\n        : \"r\"(dbrec_head));\n    st_na_release(dbrec_ptr, dbrec_val);\n}\n\n__device__ static __forceinline__ void ibgda_ring_db(nvshmemi_ibgda_device_qp_t* qp, uint16_t prod_idx) {\n    auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf);\n    ibgda_ctrl_seg_t ctrl_seg = {.opmod_idx_opcode = HtoBE32(prod_idx << 8), .qpn_ds = HtoBE32(qp->qpn << 8)};\n\n    EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), \"\");\n    st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg)));\n}\n\n__device__ static __forceinline__ void ibgda_post_send(nvshmemi_ibgda_device_qp_t* qp, uint64_t new_prod_idx) {\n    nvshmemi_ibgda_device_qp_management_t* mvars = &qp->mvars;\n    uint64_t old_prod_idx;\n\n    // Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence\n    ibgda_lock_acquire(&mvars->post_send_lock);\n\n    old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);\n    if (new_prod_idx > old_prod_idx) {\n        ibgda_update_dbr(qp, new_prod_idx);\n        ibgda_ring_db(qp, new_prod_idx);\n    }\n    ibgda_lock_release(&mvars->post_send_lock);\n}\n\ntemplate <bool kAlwaysDoPostSend>\n__device__ static __forceinline__ void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t* qp,\n                                                             uint64_t base_wqe_idx,\n                                                             uint32_t num_wqes,\n                                                             int message_idx = 0) {\n    auto state = ibgda_get_state();\n    nvshmemi_ibgda_device_qp_management_t* mvars = &qp->mvars;\n    uint64_t new_wqe_idx = base_wqe_idx + num_wqes;\n\n    // WQE writes must be 
finished first\n    __threadfence();\n\n    unsigned long long int* ready_idx =\n        (unsigned long long int*)(state->use_async_postsend ? qp->tx_wq.prod_idx : &mvars->tx_wq.ready_head);\n\n    // Wait for prior WQE slots to be filled first\n    while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx)\n        ;\n\n    // When `use_async_postsend` is disabled, ring the doorbell on every call if `kAlwaysDoPostSend` is set,\n    // otherwise once every `kNumRequestInBatch` messages\n    if (!state->use_async_postsend) {\n        constexpr int kNumRequestInBatch = 4;\n        if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)\n            ibgda_post_send(qp, new_wqe_idx);\n    }\n}\n\n__device__ static __forceinline__ void ibgda_write_rdma_write_inl_wqe(\n    nvshmemi_ibgda_device_qp_t* qp, const uint32_t* val, uint64_t raddr, __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) {\n    ibgda_ctrl_seg_t ctrl_seg;\n    struct mlx5_wqe_raddr_seg raddr_seg;\n    struct mlx5_wqe_inl_data_seg inl_seg;\n\n    auto* ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);\n    auto* raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));\n    auto* inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));\n    auto* wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr));\n\n    raddr_seg.raddr = HtoBE64(raddr);\n    raddr_seg.rkey = rkey;\n    raddr_seg.reserved = 0;\n\n    inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG);\n\n    // `imm == std::numeric_limits<uint32_t>::max()` means no imm writes\n    ctrl_seg = {0};\n    ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);\n    ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;\n    ctrl_seg.opmod_idx_opcode =\n        HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? 
MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE));\n    if (imm != std::numeric_limits<uint32_t>::max())\n        ctrl_seg.imm = HtoBE32(imm);\n\n    EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, \"sizeof(*ctrl_seg_ptr) == 16\");\n    EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, \"sizeof(*raddr_seg_ptr) == 16\");\n    EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, \"sizeof(*inl_seg_ptr) == 4\");\n    st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));\n    st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));\n    st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg));\n    st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val));\n}\n\n__device__ static __forceinline__ uint64_t\nibgda_get_lkey_and_rkey(uint64_t laddr, __be32* lkey, uint64_t raddr, int dst_pe, uint64_t* out_raddr, __be32* out_rkey, uint32_t dev_idx) {\n    auto state = ibgda_get_state();\n    auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);\n    auto log2_cumem_granularity = state->log2_cumem_granularity;\n\n    // Local key\n    uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx;\n    auto device_key = state->constmem.lkeys[idx];\n    auto lchunk_size = device_key.next_addr - laddr;\n    *lkey = device_key.key;\n\n    // Remote key\n    uint64_t roffset = raddr - heap_start;\n\n    idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * state->num_devices_initialized +\n        dst_pe * state->num_devices_initialized + dev_idx;\n    if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {\n        device_key = state->constmem.rkeys[idx];\n    } else {\n        device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];\n    }\n    *out_raddr = 
reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;\n    *out_rkey = device_key.key;\n\n    // Return the minimum of local and remote chunk sizes\n    auto rchunk_size = device_key.next_addr - roffset;\n    return min(lchunk_size, rchunk_size);\n}\n\n__device__ static __forceinline__ void ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t* out_raddr, __be32* out_rkey, uint32_t dev_idx) {\n    auto state = ibgda_get_state();\n    auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);\n\n    uint64_t roffset = addr - heap_start;\n    uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized) +\n        dst_pe * state->num_devices_initialized + dev_idx;\n    nvshmemi_ibgda_device_key_t device_key;\n    if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)\n        device_key = state->constmem.rkeys[idx];\n    else\n        device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];\n    *out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;\n    *out_rkey = device_key.key;\n}\n\n__device__ static __forceinline__ uint64_t ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t* qp, uint32_t num_wqes) {\n    auto mvars = &qp->mvars;\n    return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));\n}\n\n__device__ static __forceinline__ void* ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {\n    uint16_t cnt = qp->tx_wq.nwqes;\n    uint16_t idx = wqe_idx & (cnt - 1);\n    return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));\n}\n\n__device__ static __forceinline__ void nvshmemi_ibgda_rma_p(\n    int* rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {\n    // Get rkey\n    // NOTES: the `p` operation will not 
cross multiple remote chunks\n    __be32 rkey;\n    uint64_t raddr;\n    auto qp = ibgda_get_rc(dst_pe, qp_id);\n    ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey, qp->dev_idx);\n\n    // Write WQEs\n    uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);\n    void* wqe_ptrs;\n    wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);\n    ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);\n\n    // Submit requests\n    ibgda_submit_requests<true>(qp, base_wqe_idx, 1);\n}\n\n__device__ static __forceinline__ void ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t* qp,\n                                                                  uint64_t laddr,\n                                                                  __be32 lkey,\n                                                                  uint64_t raddr,\n                                                                  __be32 rkey,\n                                                                  uint32_t bytes,\n                                                                  uint16_t wqe_idx,\n                                                                  void** out_wqes) {\n    ibgda_ctrl_seg_t ctrl_seg;\n    struct mlx5_wqe_raddr_seg raddr_seg;\n    struct mlx5_wqe_data_seg data_seg;\n\n    auto* ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);\n    void* av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));\n    struct mlx5_wqe_raddr_seg* raddr_seg_ptr;\n    struct mlx5_wqe_data_seg* data_seg_ptr;\n\n    raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));\n    data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));\n\n    raddr_seg.raddr = HtoBE64(raddr);\n    raddr_seg.rkey = rkey;\n    raddr_seg.reserved = 0;\n\n    
data_seg.byte_count = HtoBE32(bytes);\n    data_seg.lkey = lkey;\n    data_seg.addr = HtoBE64(laddr);\n\n    ctrl_seg = {0};\n    ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);\n    ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;\n    ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);\n\n    EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, \"sizeof(*ctrl_seg_ptr) == 16\");\n    EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, \"sizeof(*raddr_seg_ptr) == 16\");\n    EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, \"sizeof(*data_seg_ptr) == 16\");\n    st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));\n    st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));\n    st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));\n}\n\n__device__ static __forceinline__ void ibgda_write_empty_recv_wqe(void* out_wqe) {\n    auto* data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);\n    struct mlx5_wqe_data_seg data_seg;\n\n    // Make the first segment in the WQE invalid so that the entire scatter list is treated as invalid\n    data_seg.byte_count = 0;\n    // `lkey` is a 32-bit big-endian field, so it needs the 32-bit byte swap\n    data_seg.lkey = HtoBE32(MLX5_INVALID_LKEY);\n    data_seg.addr = 0;\n\n    EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), \"Invalid data type length\");\n    st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));\n}\n\ntemplate <bool kAlwaysDoPostSend = false>\n__device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp(\n    uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {\n    // Get lkey and rkey, store them into lanes\n    uint32_t num_wqes = 0;\n    __be32 my_lkey = 0;\n    uint64_t my_laddr = 0;\n    __be32 my_rkey = 0;\n    uint64_t my_raddr = 0;\n    uint64_t my_chunk_size = 0;\n\n    auto qp = 
ibgda_get_rc(dst_pe, qp_id);\n\n    // Decide how many 
messages (theoretically 3 for maximum)\n    auto remaining_bytes = bytes;\n    while (remaining_bytes > 0) {\n        if (lane_id == num_wqes) {\n            my_chunk_size = min(remaining_bytes,\n                                ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey, qp->dev_idx));\n        }\n\n        // Move one more message\n        auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));\n        remaining_bytes -= chunk_size;\n        req_lptr += chunk_size;\n        req_rptr += chunk_size;\n        ++num_wqes;\n    }\n    EP_DEVICE_ASSERT(num_wqes <= 32);\n\n    // Process WQE\n    uint64_t base_wqe_idx = 0;\n    if (lane_id == 0)\n        base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);\n    base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);\n    if (lane_id < num_wqes) {\n        auto wqe_idx = base_wqe_idx + lane_id;\n        auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);\n        ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size, wqe_idx, &wqe_ptr);\n    }\n    __syncwarp();\n\n    // Submit\n    if (lane_id == 0)\n        ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);\n    __syncwarp();\n}\n\n__device__ static __forceinline__ void ibgda_write_amo_add_wqe(nvshmemi_ibgda_device_qp_t* qp,\n                                                               const int& value,\n                                                               uint64_t laddr,\n                                                               __be32 lkey,\n                                                               uint64_t raddr,\n                                                               __be32 rkey,\n                                                               uint16_t wqe_idx,\n                                                               void** out_wqes) {\n    ibgda_ctrl_seg_t ctrl_seg = {0};\n    struct 
mlx5_wqe_raddr_seg raddr_seg;\n    struct mlx5_wqe_atomic_seg atomic_seg_1;\n    struct mlx5_wqe_data_seg data_seg;\n\n    auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);\n    auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));\n    auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));\n    auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));\n\n    raddr_seg.raddr = HtoBE64(raddr);\n    raddr_seg.rkey = rkey;\n    raddr_seg.reserved = 0;\n\n    // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`\n    ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);\n    auto atomic_32_masked_fa_seg = reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1);\n    atomic_32_masked_fa_seg->add_data = HtoBE32(value);\n    atomic_32_masked_fa_seg->field_boundary = 0;\n\n    ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);\n    ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;\n\n    data_seg.byte_count = HtoBE32(sizeof(int));\n    data_seg.lkey = lkey;\n    data_seg.addr = HtoBE64(laddr);\n\n    EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), \"Invalid vectorization\");\n    EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), \"Invalid vectorization\");\n    EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), \"Invalid vectorization\");\n    EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), \"Invalid vectorization\");\n    st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<int4*>(&ctrl_seg));\n    st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<int4*>(&raddr_seg));\n    st_na_relaxed(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<int4*>(&atomic_seg_1));\n    
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));\n}\n\n__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(\n    void* rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {\n    if (is_local_copy) {\n        atomicAdd(static_cast<unsigned long long*>(rptr), value);\n    } else {\n        nvshmemi_ibgda_device_qp_t* qp = ibgda_get_rc(pe, qp_id);\n\n        __be32 rkey;\n        uint64_t raddr;\n        ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey, qp->dev_idx);\n\n        uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);\n        void* wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);\n\n        ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf), qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);\n\n        ibgda_submit_requests<true>(qp, my_wqe_idx, 1);\n    }\n}\n\n__device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) {\n    // Local rank, no need for mapping\n    if (rank == dst_rank)\n        return ptr;\n    auto peer_base = __ldg(reinterpret_cast<uint64_t*>(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank);\n\n    // RDMA connected\n    if (peer_base == 0)\n        return 0;\n\n    // NVLink P2P is enabled\n    return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));\n}\n\n// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.\n// Note that this implementation does not guarantee thread safety,\n// so we must ensure that no other threads are concurrently using the same QP.\n__device__ static __forceinline__ void ibgda_poll_cq(nvshmemi_ibgda_device_cq_t* cq, uint64_t idx) {\n    const auto cqe64 = static_cast<mlx5_cqe64*>(cq->cqe);\n    const uint32_t ncqes = cq->ncqes;\n    memory_fence_cta();\n    if (*cq->cons_idx >= idx)\n        return;\n    // NOTES: this while loop is part of do-while below.\n    // `wqe_counter` is the HW 
consumer index. However, we always maintain `index + 1`.\n    // To be able to compare with the index, we need to use `wqe_counter + 1`.\n    // Because `wqe_counter` is `uint16_t`, it may overflow. Still, we know for\n    // sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1` is less than\n    // `idx`, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1`.\n    // That's why we use `- 2` here: this case then wraps around to `0xffff`, so the loop exits.\n    uint16_t wqe_counter;\n    do {\n        wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter));\n    } while ((static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes));\n    *cq->cons_idx = idx;\n\n    // Prevent reordering of this function and later instructions\n    memory_fence_cta();\n}\n\n// Wait until all WQEs submitted to the QP so far are completed, i.e., until WQE `prod_idx - 1` is done.\n__device__ static __forceinline__ void nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {\n    auto qp = ibgda_get_rc(dst_pe, qp_id);\n    auto state = ibgda_get_state();\n    uint64_t prod_idx = state->use_async_postsend ? ld_na_relaxed(qp->tx_wq.prod_idx) : ld_na_relaxed(&qp->mvars.tx_wq.ready_head);\n    ibgda_poll_cq(qp->tx_wq.cq, prod_idx);\n}\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/internode.cu",
    "content": "#include <functional>\n#include <optional>\n\n#include \"buffer.cuh\"\n#include \"configs.cuh\"\n#include \"exception.cuh\"\n#include \"ibgda_device.cuh\"\n#include \"launch.cuh\"\n#include \"utils.cuh\"\n\nnamespace deep_ep {\n\nnamespace internode {\n\nextern nvshmem_team_t cpu_rdma_team;\n\nstruct SourceMeta {\n    int src_rdma_rank, is_token_in_nvl_rank_bits;\n\n    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, \"Invalid number of maximum NVL peers\");\n\n    __forceinline__ SourceMeta() = default;\n\n    // TODO: faster encoding\n    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) {\n        src_rdma_rank = rdma_rank;\n        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];\n        #pragma unroll\n        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i)\n            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;\n    }\n\n    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const { return (is_token_in_nvl_rank_bits >> nvl_rank) & 1; }\n};\n\nEP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, \"Invalid size of `SourceMeta`\");\n\nint get_source_meta_bytes() {\n    return sizeof(SourceMeta);\n}\n\n__host__ __device__ __forceinline__ int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {\n    return static_cast<int>(align_up(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) +\n                                         num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float),\n                                     sizeof(int4)));\n}\n\n__host__ __device__ __forceinline__ std::pair<int, int> get_rdma_clean_meta(int hidden_int4,\n                                                                            int num_scales,\n                                                                            int num_topk_idx,\n                                                                            int 
num_topk_weights,\n                                                                            int num_rdma_ranks,\n                                                                            int num_rdma_recv_buffer_tokens,\n                                                                            int num_channels) {\n    // Return `int32_t` offset and count to clean\n    return {(get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens *\n             num_rdma_ranks * 2 * num_channels) /\n                sizeof(int),\n            (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels};\n}\n\n__host__ __device__ __forceinline__ std::pair<int, int> get_nvl_clean_meta(int hidden_int4,\n                                                                           int num_scales,\n                                                                           int num_topk_idx,\n                                                                           int num_topk_weights,\n                                                                           int num_rdma_ranks,\n                                                                           int num_nvl_ranks,\n                                                                           int num_nvl_recv_buffer_tokens,\n                                                                           int num_channels,\n                                                                           bool is_dispatch) {\n    // Return `int32_t` offset and count to clean\n    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, \"Invalid size of `SourceMeta`\");\n\n    return {\n        (num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks *\n         num_channels) /\n            sizeof(int),\n        num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,\n    };\n}\n\ntemplate <bool 
kLowLatencyMode>\n__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) {\n    return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;\n}\n\ntemplate <bool kLowLatencyMode>\n__forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {\n    kLowLatencyMode ? void(nvshmem_sync(rdma_team)) : nvshmem_sync_all();\n}\n\ntemplate <bool kLowLatencyMode, int kNumRDMARanks>\n__global__ void notify_dispatch(const int* num_tokens_per_rank,\n                                int* moe_recv_counter_mapped,\n                                int num_ranks,\n                                const int* num_tokens_per_rdma_rank,\n                                int* moe_recv_rdma_counter_mapped,\n                                const int* num_tokens_per_expert,\n                                int* moe_recv_expert_counter_mapped,\n                                int num_experts,\n                                const bool* is_token_in_rank,\n                                int num_tokens,\n                                int num_worst_tokens,\n                                int num_channels,\n                                int expert_alignment,\n                                const int rdma_clean_offset,\n                                const int rdma_num_int_clean,\n                                const int nvl_clean_offset,\n                                const int nvl_num_int_clean,\n                                int* rdma_channel_prefix_matrix,\n                                int* recv_rdma_rank_prefix_sum,\n                                int* gbl_channel_prefix_matrix,\n                                int* recv_gbl_rank_prefix_sum,\n                                void* rdma_buffer_ptr,\n                                void** buffer_ptrs,\n                                int** barrier_signal_ptrs,\n                                int rank,\n                      
          const nvshmem_team_t rdma_team) {\n    auto sm_id = static_cast<int>(blockIdx.x);\n    auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();\n    auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;\n\n    auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;\n    auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS;\n\n    if (sm_id == 0) {\n        // Communication with others\n        // Global barrier: the first warp does intra-node sync, the second warp does internode sync\n        EP_DEVICE_ASSERT(num_warps > 1);\n        EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);\n\n        // Wait for all previous in-flight WRs to complete, so that they cannot\n        // overwrite the RDMA buffer after it has been cleared below\n        auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;\n        for (int i = thread_id; i < qps_per_rdma_rank * (kNumRDMARanks - 1); i += num_threads) {\n            auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % kNumRDMARanks;\n            auto qp_id = i % qps_per_rdma_rank;\n            nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), qp_id);\n        }\n        __syncthreads();\n\n        if (thread_id == 32)\n            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);\n        barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);\n\n        // Send numbers of tokens per rank/expert to RDMA ranks\n        auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);\n        auto rdma_recv_num_tokens_mixed = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);\n\n        // Clean up for later data dispatch\n        EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int));\n        #pragma unroll\n        
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)\n            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;\n\n        // Copy to send buffer\n        #pragma unroll\n        for (int i = thread_id; i < num_ranks; i += num_threads)\n            rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i];\n        #pragma unroll\n        for (int i = thread_id; i < num_experts; i += num_threads)\n            rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] =\n                num_tokens_per_expert[i];\n        if (thread_id < kNumRDMARanks)\n            rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id];\n        __syncthreads();\n\n        // Issue send\n        // TODO: more light fence or barrier or signaling\n        // TODO: overlap EP barrier and NVL cleaning\n        for (int i = warp_id; i < kNumRDMARanks; i += num_warps) {\n            if (i != rdma_rank) {\n                nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)),\n                                                  reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.send_buffer(i)),\n                                                  (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int),\n                                                  translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank),\n                                                  0,\n                                                  lane_id,\n                                                  0);\n            } else {\n                UNROLLED_WARP_COPY(1,\n                                   lane_id,\n                                   NUM_MAX_NVL_PEERS + num_rdma_experts + 1,\n                                   rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),\n                 
                  rdma_recv_num_tokens_mixed.send_buffer(i),\n                                   ld_volatile_global,\n                                   st_na_global);\n            }\n        }\n        __syncthreads();\n\n        // Wait previous operations to be finished\n        if (thread_id < kNumRDMARanks and thread_id != rdma_rank)\n            nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank), 0);\n        __syncthreads();\n\n        // Barrier\n        if (thread_id == 0)\n            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);\n        __syncthreads();\n\n        // NVL buffers\n        auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;\n        auto nvl_recv_buffer = buffer_ptrs[nvl_rank];\n        auto nvl_reduced_num_tokens_per_expert = Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);\n        auto nvl_send_num_tokens_per_rank = AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);\n        auto nvl_send_num_tokens_per_expert = AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);\n        auto nvl_recv_num_tokens_per_rank = AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);\n        auto nvl_recv_num_tokens_per_expert = AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);\n\n        // Clean up for later data dispatch\n        auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);\n        EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +\n                             nvl_send_num_tokens_per_expert.total_bytes <=\n                         nvl_clean_offset * sizeof(int));\n        #pragma unroll\n        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)\n            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;\n\n        // Reduce number of tokens per expert into the NVL send buffer\n       
 // TODO: may use NVSHMEM reduction\n        EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);\n        if (thread_id < num_rdma_experts) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < kNumRDMARanks; ++i)\n                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];\n            nvl_reduced_num_tokens_per_expert[thread_id] = sum;\n        }\n        __syncthreads();\n\n        // Reduce RDMA received tokens\n        if (thread_id == 0) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < kNumRDMARanks; ++i) {\n                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];\n                recv_rdma_rank_prefix_sum[i] = sum;\n            }\n            if (num_worst_tokens == 0) {\n                while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)\n                    ;\n                *moe_recv_rdma_counter_mapped = sum;\n            }\n        }\n\n        // Send numbers of tokens per rank/expert to NVL ranks\n        EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);\n        if (thread_id < NUM_MAX_NVL_PEERS) {\n            #pragma unroll\n            for (int i = 0; i < kNumRDMARanks; ++i)\n                nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];\n            #pragma unroll\n            for (int i = 0; i < num_nvl_experts; ++i)\n                nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];\n        }\n        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);\n\n        // Reduce the number of tokens per rank/expert\n        EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);\n        if (thread_id == 0) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < num_ranks; ++i) {\n                int src_rdma_rank = 
i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;\n                sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];\n                recv_gbl_rank_prefix_sum[i] = sum;\n            }\n            if (num_worst_tokens == 0) {\n                while (ld_volatile_global(moe_recv_counter_mapped) != -1)\n                    ;\n                *moe_recv_counter_mapped = sum;\n            }\n        }\n        if (thread_id < num_nvl_experts) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)\n                sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];\n            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;\n            if (num_worst_tokens == 0) {\n                while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)\n                    ;\n                moe_recv_expert_counter_mapped[thread_id] = sum;\n            }\n        }\n\n        // Finally barrier\n        if (thread_id == 32)\n            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);\n        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);\n    } else {\n        // Calculate meta data\n        int dst_rdma_rank = sm_id - 1;\n        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {\n            int token_start_idx, token_end_idx;\n            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);\n\n            // Iterate over tokens\n            int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};\n            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) {\n                EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), \"Invalid number of NVL peers\");\n                auto is_token_in_rank_uint64 =\n                    *reinterpret_cast<const uint64_t*>(is_token_in_rank + i * 
num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);\n                auto is_token_in_rank_values = reinterpret_cast<const bool*>(&is_token_in_rank_uint64);\n                #pragma unroll\n                for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j)\n                    per_nvl_rank_count[j] += is_token_in_rank_values[j];\n                total_count += (is_token_in_rank_uint64 != 0);\n            }\n\n            // Warp reduce\n            total_count = warp_reduce_sum(total_count);\n            #pragma unroll\n            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)\n                per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);\n\n            // Write into channel matrix\n            if (elect_one_sync()) {\n                #pragma unroll\n                for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)\n                    gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i];\n                rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;\n            }\n        }\n\n        // Calculate prefix sum\n        __syncthreads();\n        if (thread_id == 0) {\n            auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;\n            #pragma unroll\n            for (int i = 1; i < num_channels; ++i)\n                prefix_row[i] += prefix_row[i - 1];\n        }\n\n        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, \"Invalid number of NVL peers\");\n        if (thread_id < NUM_MAX_NVL_PEERS) {\n            auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;\n            #pragma unroll\n            for (int i = 1; i < num_channels; ++i)\n                prefix_row[i] += prefix_row[i - 1];\n        }\n    }\n}\n\nvoid notify_dispatch(const int* num_tokens_per_rank,\n                     int* moe_recv_counter_mapped,\n                     int num_ranks,\n                     
const int* num_tokens_per_rdma_rank,\n                     int* moe_recv_rdma_counter_mapped,\n                     const int* num_tokens_per_expert,\n                     int* moe_recv_expert_counter_mapped,\n                     int num_experts,\n                     const bool* is_token_in_rank,\n                     int num_tokens,\n                     int num_worst_tokens,\n                     int num_channels,\n                     int hidden_int4,\n                     int num_scales,\n                     int num_topk,\n                     int expert_alignment,\n                     int* rdma_channel_prefix_matrix,\n                     int* recv_rdma_rank_prefix_sum,\n                     int* gbl_channel_prefix_matrix,\n                     int* recv_gbl_rank_prefix_sum,\n                     void* rdma_buffer_ptr,\n                     int num_max_rdma_chunked_recv_tokens,\n                     void** buffer_ptrs,\n                     int num_max_nvl_chunked_recv_tokens,\n                     int** barrier_signal_ptrs,\n                     int rank,\n                     cudaStream_t stream,\n                     int64_t num_rdma_bytes,\n                     int64_t num_nvl_bytes,\n                     bool low_latency_mode) {\n#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                                                    \\\n    {                                                                                                                                  \\\n        auto notify_dispatch_func = low_latency_mode ? 
notify_dispatch<true, num_rdma_ranks> : notify_dispatch<false, num_rdma_ranks>; \\\n        LAUNCH_KERNEL(&cfg,                                                                                                            \\\n                      notify_dispatch_func,                                                                                            \\\n                      num_tokens_per_rank,                                                                                             \\\n                      moe_recv_counter_mapped,                                                                                         \\\n                      num_ranks,                                                                                                       \\\n                      num_tokens_per_rdma_rank,                                                                                        \\\n                      moe_recv_rdma_counter_mapped,                                                                                    \\\n                      num_tokens_per_expert,                                                                                           \\\n                      moe_recv_expert_counter_mapped,                                                                                  \\\n                      num_experts,                                                                                                     \\\n                      is_token_in_rank,                                                                                                \\\n                      num_tokens,                                                                                                      \\\n                      num_worst_tokens,                                                                                                \\\n                      num_channels,                                                                          
                          \\\n                      expert_alignment,                                                                                                \\\n                      rdma_clean_meta.first,                                                                                           \\\n                      rdma_clean_meta.second,                                                                                          \\\n                      nvl_clean_meta.first,                                                                                            \\\n                      nvl_clean_meta.second,                                                                                           \\\n                      rdma_channel_prefix_matrix,                                                                                      \\\n                      recv_rdma_rank_prefix_sum,                                                                                       \\\n                      gbl_channel_prefix_matrix,                                                                                       \\\n                      recv_gbl_rank_prefix_sum,                                                                                        \\\n                      rdma_buffer_ptr,                                                                                                 \\\n                      buffer_ptrs,                                                                                                     \\\n                      barrier_signal_ptrs,                                                                                             \\\n                      rank,                                                                                                            \\\n                      cpu_rdma_team);                                                                                                  \\\n    }                   
                                                                                                               \\\n    break\n\n    constexpr int kNumThreads = 512;\n    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;\n\n    // Get clean meta\n    auto rdma_clean_meta =\n        get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);\n    auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4,\n                                             num_scales,\n                                             num_topk,\n                                             num_topk,\n                                             num_rdma_ranks,\n                                             NUM_MAX_NVL_PEERS,\n                                             num_max_nvl_chunked_recv_tokens,\n                                             num_channels,\n                                             true);\n    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);\n    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);\n    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());\n    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());\n\n    // Launch kernel\n    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);\n    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);\n#undef NOTIFY_DISPATCH_LAUNCH_CASE\n}\n\n// At most 8 RDMA ranks to be sent\nconstexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {\n    return num_rdma_ranks < 8 ? 
num_rdma_ranks : 8;\n}\n\ntemplate <bool kLowLatencyMode,\n          int kNumRDMARanks,\n          bool kCachedMode,\n          int kNumTMABytesPerWarp,\n          int kNumDispatchRDMASenderWarps,\n          int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>\n__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1)\n    dispatch(int4* recv_x,\n             float* recv_x_scales,\n             topk_idx_t* recv_topk_idx,\n             float* recv_topk_weights,\n             SourceMeta* recv_src_meta,\n             const int4* x,\n             const float* x_scales,\n             const topk_idx_t* topk_idx,\n             const float* topk_weights,\n             int* send_rdma_head,\n             int* send_nvl_head,\n             int* recv_rdma_channel_prefix_matrix,\n             int* recv_gbl_channel_prefix_matrix,\n             const int* rdma_channel_prefix_matrix,\n             const int* recv_rdma_rank_prefix_sum,\n             const int* gbl_channel_prefix_matrix,\n             const int* recv_gbl_rank_prefix_sum,\n             const bool* is_token_in_rank,\n             int num_tokens,\n             int num_worst_tokens,\n             int hidden_int4,\n             int num_scales,\n             int num_topk,\n             int num_experts,\n             int scale_token_stride,\n             int scale_hidden_stride,\n             void* rdma_buffer_ptr,\n             int num_max_rdma_chunked_send_tokens,\n             int num_max_rdma_chunked_recv_tokens,\n             void** buffer_ptrs,\n             int num_max_nvl_chunked_send_tokens,\n             int num_max_nvl_chunked_recv_tokens,\n             int rank,\n             int num_ranks) {\n    enum class WarpRole { kRDMASender, kRDMASenderCoordinator, kRDMAAndNVLForwarder, kForwarderCoordinator, kNVLReceivers };\n\n    const auto num_sms = static_cast<int>(gridDim.x);\n    const auto sm_id = static_cast<int>(blockIdx.x);\n    const auto num_threads = 
static_cast<int>(blockDim.x), num_warps = num_threads / 32;\n    const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();\n    const auto num_channels = num_sms / 2, channel_id = sm_id / 2;\n    const bool is_forwarder = sm_id % 2 == 0;\n    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;\n\n    EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms);\n\n    const auto role_meta = [=]() -> std::pair<WarpRole, int> {\n        if (is_forwarder) {\n            if (warp_id < NUM_MAX_NVL_PEERS) {\n                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};\n            } else {\n                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};\n            }\n        } else if (warp_id < kNumDispatchRDMASenderWarps) {\n            return {WarpRole::kRDMASender, -1};\n        } else if (warp_id == kNumDispatchRDMASenderWarps) {\n            return {WarpRole::kRDMASenderCoordinator, -1};\n        } else {\n            return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};\n        }\n    }();\n    auto warp_role = role_meta.first;\n    auto target_rank = role_meta.second;  // Not applicable for RDMA senders\n    EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS);\n\n    // Data checks\n    EP_DEVICE_ASSERT(num_topk <= 32);\n\n    // RDMA symmetric layout\n    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), \"Invalid number of NVL peers\");\n    auto hidden_bytes = hidden_int4 * sizeof(int4);\n    auto scale_bytes = num_scales * sizeof(float);\n    auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk);\n    auto rdma_channel_data = SymBuffer<uint8_t>(\n        rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * 
num_bytes_per_token, kNumRDMARanks, channel_id, num_channels);\n    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);\n    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);\n    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);\n\n    // NVL buffer layouts\n    // NOTES: `rs_wr_buffer_ptr` means \"Read for Senders, Write for Receivers\", `ws_rr_buffer_ptr` means \"Write for Senders, Read for\n    // Receivers\"\n    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;\n    int rs_wr_rank = 0, ws_rr_rank = 0;\n    if (warp_role == WarpRole::kRDMAAndNVLForwarder)\n        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank,\n        ws_rr_rank = target_rank;\n    if (warp_role == WarpRole::kNVLReceivers)\n        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank,\n        ws_rr_rank = nvl_rank;\n\n    // Allocate buffers\n    auto nvl_channel_x = AsymBuffer<uint8_t>(ws_rr_buffer_ptr,\n                                             num_max_nvl_chunked_recv_tokens * num_bytes_per_token,\n                                             NUM_MAX_NVL_PEERS,\n                                             channel_id,\n                                             num_channels,\n                                             rs_wr_rank)\n                             .advance_also(rs_wr_buffer_ptr);\n    auto nvl_channel_prefix_start =\n        AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)\n            .advance_also(rs_wr_buffer_ptr);\n    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)\n                           
           .advance_also(rs_wr_buffer_ptr);\n    auto nvl_channel_head =\n        AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);\n    auto nvl_channel_tail =\n        AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);\n\n    // RDMA sender warp synchronization\n    // NOTES: `rdma_send_channel_tail` means the latest released tail\n    // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status\n    __shared__ int rdma_send_channel_lock[kNumRDMARanks];\n    __shared__ int rdma_send_channel_tail[kNumRDMARanks];\n    __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];\n    auto sync_rdma_sender_smem = []() { asm volatile(\"barrier.sync 0, %0;\" ::\"r\"((kNumDispatchRDMASenderWarps + 1) * 32)); };\n\n    // TMA stuffs\n    extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];\n    auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp;\n    auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + num_bytes_per_token);\n    uint32_t tma_phase = 0;\n    if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and elect_one_sync()) {\n        mbarrier_init(tma_mbarrier, 1);\n        fence_barrier_init();\n        EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerWarp);\n    }\n    __syncwarp();\n\n    // Forward warp synchronization\n    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];\n    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];\n    auto sync_forwarder_smem = []() { asm volatile(\"barrier.sync 1, %0;\" ::\"r\"((NUM_MAX_NVL_PEERS + 1) * 32)); };\n\n    if (warp_role == WarpRole::kRDMASender) {\n        // Get tasks\n        int token_start_idx, token_end_idx;\n        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, 
token_end_idx);\n\n        // Send number of tokens in this channel by `-value - 1`\n        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, \"Invalid number of NVL peers\");\n        for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {\n            auto dst_ptr =\n                dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);\n            if (lane_id < NUM_MAX_NVL_PEERS) {\n                dst_ptr[lane_id] =\n                    -(channel_id == 0\n                          ? 0\n                          : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) -\n                    1;\n            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {\n                dst_ptr[lane_id] =\n                    -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels +\n                                               channel_id] -\n                    1;\n            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {\n                dst_ptr[lane_id] = -(channel_id == 0 ? 
0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;\n            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {\n                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;\n            }\n            __syncwarp();\n\n            // Issue RDMA for non-local ranks\n            if (dst_rdma_rank != rdma_rank) {\n                nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),\n                                                  reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),\n                                                  sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),\n                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),\n                                                  channel_id,\n                                                  lane_id,\n                                                  0);\n            }\n        }\n        sync_rdma_sender_smem();\n\n        // Iterate over tokens and copy into buffer\n        int64_t token_idx;\n        int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0;\n        auto send_buffer = lane_id == rdma_rank ? 
rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);\n        for (token_idx = token_start_idx; token_idx < token_end_idx; ++token_idx) {\n            // Read which RDMA ranks this token is sent to\n            uint64_t is_token_in_rank_uint64 = 0;\n            if (lane_id < kNumRDMARanks) {\n                is_token_in_rank_uint64 =\n                    __ldg(reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS));\n                global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);\n            }\n            __syncwarp();\n\n            // Skip tokens that do not belong to this warp\n            if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id)\n                continue;\n            auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1;\n\n            // Wait for the remote buffer slot to be released\n            auto start_time = clock64();\n            while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {\n                cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));\n\n                // Timeout check\n                if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {\n                    printf(\"DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\\n\",\n                           channel_id,\n                           rdma_rank,\n                           nvl_rank,\n                           lane_id,\n                           cached_rdma_channel_head,\n                           rdma_tail_idx);\n                    trap();\n                }\n            }\n            __syncwarp();\n\n            // Store RDMA head for combine\n            if (lane_id < kNumRDMARanks and not kCachedMode)\n                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = 
rdma_tail_idx;\n\n            // Broadcast tails\n            SourceMeta src_meta;\n            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];\n            void* dst_send_buffers[kNumTopkRDMARanks];\n            #pragma unroll\n            for (int i = 0, slot_idx; i < kNumRDMARanks; ++i)\n                if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) {\n                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;\n                    topk_ranks[num_topk_ranks] = i;\n                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);\n                    auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);\n                    if (lane_id == num_topk_ranks)\n                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);\n                    dst_send_buffers[num_topk_ranks++] =\n                        reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_token;\n                }\n            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);\n\n            // Copy `x` into symmetric send buffer\n            auto st_broadcast = [=](const int key, const int4& value) {\n                #pragma unroll\n                for (int j = 0; j < num_topk_ranks; ++j)\n                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);\n            };\n            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);\n            #pragma unroll\n            for (int i = 0; i < num_topk_ranks; ++i)\n                dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;\n\n            // Copy `x_scales` into symmetric send buffer\n            #pragma unroll\n            for (int i = lane_id; i < num_scales; i += 32) {\n                auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;\n       
         auto value = ld_nc_global(x_scales + offset);\n                #pragma unroll\n                for (int j = 0; j < num_topk_ranks; ++j)\n                    st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);\n            }\n            #pragma unroll\n            for (int i = 0; i < num_topk_ranks; ++i)\n                dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;\n\n            // Copy source metadata into symmetric send buffer\n            if (lane_id < num_topk_ranks)\n                st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);\n            #pragma unroll\n            for (int i = 0; i < num_topk_ranks; ++i)\n                dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;\n\n            // Copy `topk_idx` and `topk_weights` into symmetric send buffer\n            #pragma unroll\n            for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) {\n                auto rank_idx = i / num_topk, copy_idx = i % num_topk;\n                auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));\n                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);\n                st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);\n                st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);\n            }\n            __syncwarp();\n\n            // Release the transaction in the window\n            if (is_token_in_rank_uint64 != 0) {\n                // Acquire lock first\n                acquire_lock(rdma_send_channel_lock + lane_id);\n                auto latest_tail = rdma_send_channel_tail[lane_id];\n                auto offset = rdma_tail_idx - latest_tail;\n                while (offset >= 32) {\n                    release_lock(rdma_send_channel_lock + 
lane_id);\n                    acquire_lock(rdma_send_channel_lock + lane_id);\n                    latest_tail = rdma_send_channel_tail[lane_id];\n                    offset = rdma_tail_idx - latest_tail;\n                }\n\n                // Release the transaction slot\n                // Set this slot's bit; if it is the oldest pending one, advance the tail past all consecutive completed slots\n                auto window = rdma_send_channel_window[lane_id] | (1u << offset);\n                if (offset == 0) {\n                    auto num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1;\n                    st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);\n                    window >>= num_empty_slots;\n                }\n                rdma_send_channel_window[lane_id] = window;\n\n                // Release lock\n                release_lock(rdma_send_channel_lock + lane_id);\n            }\n            __syncwarp();\n        }\n    } else if (warp_role == WarpRole::kRDMASenderCoordinator) {\n        // NOTES: the divisibility check ensures an issued put never has to be split at the end of the buffer\n        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);\n\n        // Clean shared memory\n        EP_STATIC_ASSERT(kNumRDMARanks <= 32, \"Invalid number of RDMA ranks\");\n        (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;\n        (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;\n        (lane_id < kNumRDMARanks) ? 
(rdma_send_channel_window[lane_id] = 0) : 0;\n\n        // Synchronize shared memory\n        sync_rdma_sender_smem();\n\n        // Get number of tokens to send for each RDMA rank\n        int num_tokens_to_send = 0;\n        if (lane_id < kNumRDMARanks) {\n            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];\n            if (channel_id > 0)\n                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];\n        }\n\n        // Iterate all RDMA ranks\n        int last_issued_tail = 0;\n        auto start_time = clock64();\n        while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {\n            // Timeout check\n            if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {\n                printf(\"DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\\n\",\n                       channel_id,\n                       rdma_rank,\n                       nvl_rank,\n                       lane_id,\n                       last_issued_tail,\n                       num_tokens_to_send);\n                trap();\n            }\n\n            // TODO: try thread-level `put_nbi`?\n            for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {\n                // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels\n                int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;\n                synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank);\n                if (synced_num_tokens_to_send == 0)\n                    continue;\n\n                // Read the latest progress\n                // NOTES: `rdma_send_channel_tail` does not need to be protected by lock\n                auto processed_tail =\n                    __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const 
int*>(rdma_send_channel_tail + dst_rdma_rank)), 0);\n                auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);\n                auto num_tokens_processed = processed_tail - synced_last_issued_tail;\n                if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)\n                    continue;\n\n                // Issue RDMA send\n                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);\n                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and num_tokens_to_issue <= synced_num_tokens_to_send);\n                if (dst_rdma_rank != rdma_rank) {\n                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;\n                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);\n                    const size_t num_bytes_per_msg = num_bytes_per_token * num_tokens_to_issue;\n                    const auto dst_ptr =\n                        reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_token);\n                    const auto src_ptr =\n                        reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_token);\n                    nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr,\n                                                      src_ptr,\n                                                      num_bytes_per_msg,\n                                                      translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),\n                                                      channel_id,\n                                                      lane_id,\n                                                      0);\n                } else {\n                    // Lighter fence for local RDMA rank\n                    
memory_fence();\n                }\n                __syncwarp();\n\n                // Update tails\n                if (lane_id == dst_rdma_rank) {\n                    last_issued_tail += num_tokens_to_issue;\n                    num_tokens_to_send -= num_tokens_to_issue;\n                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank),\n                                                    num_tokens_to_issue,\n                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),\n                                                    channel_id,\n                                                    dst_rdma_rank == rdma_rank);\n                }\n                __syncwarp();\n            }\n        }\n    } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {\n        // RDMA consumers and NVL producers\n        const auto dst_nvl_rank = target_rank;\n\n        // Wait counters to arrive\n        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;\n        EP_DEVICE_ASSERT(kNumRDMARanks <= 32);\n        auto start_time = clock64();\n        if (lane_id < kNumRDMARanks) {\n            while (true) {\n                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);\n                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank);\n                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);\n                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);\n                if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {\n                    // Notify NVL ranks\n                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;\n                    EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);\n                    
st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);\n                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);\n\n                    // Save RDMA channel received token count\n                    src_rdma_channel_prefix = -meta_2 - 1;\n                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;\n                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix;\n                    if (not kCachedMode)\n                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;\n                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];\n                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);\n                    break;\n                }\n\n                // Timeout check\n                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                    printf(\n                        \"DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, \"\n                        \"meta: %d, %d, %d, %d\\n\",\n                        channel_id,\n                        rdma_rank,\n                        nvl_rank,\n                        lane_id,\n                        dst_nvl_rank,\n                        meta_0,\n                        meta_1,\n                        meta_2,\n                        meta_3);\n                    trap();\n                }\n            }\n        }\n        __syncwarp();\n\n        // Shift cached head\n        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;\n\n        // Wait shared memory to be cleaned\n        sync_forwarder_smem();\n\n        // Forward tokens from RDMA buffer\n        // NOTES: always start from the local rank\n        int src_rdma_rank = sm_id % kNumRDMARanks;\n        int cached_rdma_channel_head 
= 0, cached_rdma_channel_tail = 0;\n        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;\n        while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {\n            // Check destination queue emptiness, or wait a buffer to be released\n            start_time = clock64();\n            while (true) {\n                const int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;\n                if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)\n                    break;\n                cached_nvl_channel_head = __shfl_sync(0xffffffffu, ld_volatile_global(nvl_channel_head.buffer()), 0);\n\n                // Timeout check\n                if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                    printf(\n                        \"DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\\n\",\n                        channel_id,\n                        rdma_rank,\n                        nvl_rank,\n                        dst_nvl_rank,\n                        ld_volatile_global(nvl_channel_head.buffer()),\n                        cached_nvl_channel_tail);\n                    trap();\n                }\n            }\n\n            // Find next source RDMA rank (round-robin)\n            start_time = clock64();\n            while (true) {\n                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;\n                if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {\n                    if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)\n                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));\n                    if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank))\n               
         break;\n                }\n\n                // Timeout check\n                if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {\n                    printf(\n                        \"DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, \"\n                        \"head: %d, tail: %d, expected: %d\\n\",\n                        channel_id,\n                        rdma_rank,\n                        nvl_rank,\n                        dst_nvl_rank,\n                        lane_id,\n                        cached_rdma_channel_head,\n                        cached_rdma_channel_tail,\n                        num_tokens_to_recv_from_rdma);\n                    trap();\n                }\n            }\n            auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);\n            auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);\n\n            // Iterate over every token from the RDMA buffer\n            for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {\n                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;\n                auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_token;\n                auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));\n                lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;\n                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);\n                if (lane_id == src_rdma_rank) {\n                    auto cached_head = is_in_dst_nvl_rank ? 
rdma_nvl_token_idx : -1;\n                    rdma_nvl_token_idx += is_in_dst_nvl_rank;\n                    if (not kCachedMode)\n                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;\n                }\n                if (not is_in_dst_nvl_rank)\n                    continue;\n\n                // Get an empty slot\n                int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;\n                auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;\n\n                // Copy data\n                if (elect_one_sync()) {\n                    tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false);\n                    mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token);\n                }\n                __syncwarp();\n                mbarrier_wait(tma_mbarrier, tma_phase);\n                if (elect_one_sync())\n                    tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_token);\n                __syncwarp();\n\n                // Stop early once a full NVL chunk has been sent, as only that many free slots are guaranteed\n                if ((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)\n                    src_rdma_tail = i + 1;\n\n                // Wait for the TMA store to finish\n                tma_store_wait<0>();\n                __syncwarp();\n            }\n\n            // Sync head index\n            if (lane_id == src_rdma_rank)\n                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);\n\n            // Move tail index\n            __syncwarp();\n            if (elect_one_sync())\n                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);\n        }\n\n        // Retired\n        __syncwarp();\n        if (elect_one_sync())\n            forward_channel_retired[dst_nvl_rank] = true;\n    } else if (warp_role == WarpRole::kForwarderCoordinator) {\n        // Extra warps for 
forwarder coordinator should exit directly\n        if (target_rank > 0)\n            return;\n\n        // Forward warp coordinator\n        EP_STATIC_ASSERT(kNumRDMARanks <= 32, \"Invalid number of RDMA peers\");\n\n        // Clean shared memory\n        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, \"Invalid number of NVL peers\");\n        #pragma unroll\n        for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32)\n            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;\n        if (lane_id < NUM_MAX_NVL_PEERS)\n            forward_channel_retired[lane_id] = false;\n        sync_forwarder_smem();\n\n        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;\n        while (true) {\n            // Find minimum head\n            int min_head = std::numeric_limits<int>::max();\n            #pragma unroll\n            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)\n                if (not forward_channel_retired[i])\n                    min_head = min(min_head, forward_channel_head[i][target_rdma]);\n            if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max()))\n                break;\n\n            // Update remote head\n            if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and\n                lane_id < kNumRDMARanks) {\n                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank),\n                                                min_head - last_head,\n                                                translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank),\n                                                channel_id + num_channels,\n                                                lane_id == rdma_rank);\n                last_head = min_head;\n            }\n\n            // Nanosleep and let other warps work\n            __nanosleep(NUM_WAIT_NANOSECONDS);\n        }\n    } else {\n        // 
NVL consumers\n        // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)\n        int src_nvl_rank = target_rank, total_offset = 0;\n        const int local_expert_begin = rank * (num_experts / num_ranks);\n        const int local_expert_end = local_expert_begin + (num_experts / num_ranks);\n\n        EP_STATIC_ASSERT(kNumRDMARanks <= 32, \"Invalid number of RDMA peers\");\n        if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)\n            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];\n\n        // Receive channel offsets\n        int start_offset = 0, end_offset = 0, num_tokens_to_recv;\n        auto start_time = clock64();\n        while (lane_id < kNumRDMARanks) {\n            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);\n            end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);\n            if (start_offset < 0 and end_offset < 0) {\n                start_offset = -start_offset - 1, end_offset = -end_offset - 1;\n                total_offset += start_offset;\n                break;\n            }\n\n            // Timeout check\n            if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                printf(\n                    \"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\\n\",\n                    channel_id,\n                    rdma_rank,\n                    nvl_rank,\n                    lane_id,\n                    src_nvl_rank,\n                    start_offset,\n                    end_offset);\n                trap();\n            }\n        }\n        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);\n\n        // Save for combine usage\n        if (lane_id < kNumRDMARanks and not kCachedMode)\n            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + 
src_nvl_rank) * num_channels + channel_id] = total_offset;\n        __syncwarp();\n\n        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;\n        while (num_tokens_to_recv > 0) {\n            // Check channel status by lane 0\n            start_time = clock64();\n            while (true) {\n                // Ready to copy\n                if (cached_channel_head_idx != cached_channel_tail_idx)\n                    break;\n                cached_channel_tail_idx = __shfl_sync(0xffffffff, ld_acquire_sys_global(nvl_channel_tail.buffer()), 0);\n\n                // Timeout check\n                if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                    printf(\"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\\n\",\n                           channel_id,\n                           rdma_rank,\n                           nvl_rank,\n                           src_nvl_rank,\n                           cached_channel_head_idx,\n                           cached_channel_tail_idx);\n                    trap();\n                }\n            }\n\n            // Copy data\n            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;\n            for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {\n                int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;\n                auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_token;\n                auto meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));\n                int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);\n                (lane_id == meta.src_rdma_rank) ? 
(total_offset += 1) : 0;\n\n                bool scale_aligned = (scale_bytes % 16 == 0);\n                auto tma_load_bytes = hidden_bytes + (scale_aligned ? scale_bytes : 0);\n\n                // Copy data\n                if (elect_one_sync()) {\n                    tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);\n                    mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes);\n                }\n                __syncwarp();\n                mbarrier_wait(tma_mbarrier, tma_phase);\n                if (elect_one_sync()) {\n                    tma_store_1d(tma_buffer, recv_x + recv_token_idx * hidden_int4, hidden_bytes, false);\n                    if (scale_aligned)\n                        tma_store_1d(tma_buffer + hidden_bytes, recv_x_scales + recv_token_idx * num_scales, scale_bytes, false);\n                }\n                __syncwarp();\n                shifted += hidden_bytes;\n\n                // Copy scales\n                // TODO: make it as templated\n                if (not scale_aligned) {\n                    UNROLLED_WARP_COPY(1,\n                                       lane_id,\n                                       num_scales,\n                                       recv_x_scales + recv_token_idx * num_scales,\n                                       reinterpret_cast<float*>(shifted),\n                                       ld_nc_global,\n                                       st_na_global);\n                }\n                shifted += scale_bytes;\n\n                // Copy source meta\n                if (not kCachedMode and elect_one_sync())\n                    st_na_global(recv_src_meta + recv_token_idx, meta);\n                shifted += sizeof(SourceMeta);\n\n                // Copy `topk_idx` and `topk_weights`\n                if (lane_id < num_topk) {\n                    // Read\n                    auto idx_value = static_cast<topk_idx_t>(ld_nc_global(reinterpret_cast<int*>(shifted) + 
lane_id));\n                    auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted + sizeof(int) * num_topk) + lane_id);\n                    auto recv_idx = recv_token_idx * num_topk + lane_id;\n\n                    // Transform and write\n                    idx_value = (idx_value >= local_expert_begin and idx_value < local_expert_end) ? idx_value - local_expert_begin : -1;\n                    weight_value = idx_value >= 0 ? weight_value : 0.0f;\n                    st_na_global(recv_topk_idx + recv_idx, idx_value);\n                    st_na_global(recv_topk_weights + recv_idx, weight_value);\n                }\n\n                // Wait for the TMA store to finish\n                tma_store_wait<0>();\n                __syncwarp();\n            }\n\n            // Move queue\n            if (elect_one_sync())\n                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);\n        }\n    }\n\n    // Fill unused `recv_topk_idx` entries with -1\n    if (num_worst_tokens > 0) {\n        if (is_forwarder)\n            return;\n        // Get the actual number of tokens received by the current rank\n        int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1];\n        // Some forwarder coordinator warps exit early, so only non-forwarder SMs take part in the clean-up\n        // `channel_id * num_threads` is the offset of the current non-forwarder SM\n        const auto clean_start = num_recv_tokens * num_topk + channel_id * num_threads;\n        const auto clean_end = num_worst_tokens * num_topk;\n        const auto clean_stride = num_channels * num_threads;\n        #pragma unroll\n        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)\n            recv_topk_idx[i] = -1;\n    }\n}\n\nvoid dispatch(void* recv_x,\n              float* recv_x_scales,\n              topk_idx_t* recv_topk_idx,\n              float* recv_topk_weights,\n              void* recv_src_meta,\n              const void* x,\n              const 
float* x_scales,\n              const topk_idx_t* topk_idx,\n              const float* topk_weights,\n              int* send_rdma_head,\n              int* send_nvl_head,\n              int* recv_rdma_channel_prefix_matrix,\n              int* recv_gbl_channel_prefix_matrix,\n              const int* rdma_channel_prefix_matrix,\n              const int* recv_rdma_rank_prefix_sum,\n              const int* gbl_channel_prefix_matrix,\n              const int* recv_gbl_rank_prefix_sum,\n              const bool* is_token_in_rank,\n              int num_tokens,\n              int num_worst_tokens,\n              int hidden_int4,\n              int num_scales,\n              int num_topk,\n              int num_experts,\n              int scale_token_stride,\n              int scale_hidden_stride,\n              void* rdma_buffer_ptr,\n              int num_max_rdma_chunked_send_tokens,\n              int num_max_rdma_chunked_recv_tokens,\n              void** buffer_ptrs,\n              int num_max_nvl_chunked_send_tokens,\n              int num_max_nvl_chunked_recv_tokens,\n              int rank,\n              int num_ranks,\n              bool is_cached_dispatch,\n              cudaStream_t stream,\n              int num_channels,\n              bool low_latency_mode) {\n    constexpr int kNumDispatchRDMASenderWarps = 7;\n    constexpr int kNumTMABytesPerWarp = 16384;\n    constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS;\n\n    // Make sure never OOB\n    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());\n\n#define DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                                                   \\\n    {                                                                                                                          \\\n        auto dispatch_func = low_latency_mode                                                                                  
\\\n            ? (is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps>     \\\n                                  : dispatch<true, num_rdma_ranks, false, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps>)   \\\n            : (is_cached_dispatch ? dispatch<false, num_rdma_ranks, true, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps>    \\\n                                  : dispatch<false, num_rdma_ranks, false, kNumTMABytesPerWarp, kNumDispatchRDMASenderWarps>); \\\n        SET_SHARED_MEMORY_FOR_TMA(dispatch_func);                                                                              \\\n        LAUNCH_KERNEL(&cfg,                                                                                                    \\\n                      dispatch_func,                                                                                           \\\n                      reinterpret_cast<int4*>(recv_x),                                                                         \\\n                      recv_x_scales,                                                                                           \\\n                      recv_topk_idx,                                                                                           \\\n                      recv_topk_weights,                                                                                       \\\n                      reinterpret_cast<SourceMeta*>(recv_src_meta),                                                            \\\n                      reinterpret_cast<const int4*>(x),                                                                        \\\n                      x_scales,                                                                                                \\\n                      topk_idx,                                                                                                \\\n                      
topk_weights,                                                                                            \\\n                      send_rdma_head,                                                                                          \\\n                      send_nvl_head,                                                                                           \\\n                      recv_rdma_channel_prefix_matrix,                                                                         \\\n                      recv_gbl_channel_prefix_matrix,                                                                          \\\n                      rdma_channel_prefix_matrix,                                                                              \\\n                      recv_rdma_rank_prefix_sum,                                                                               \\\n                      gbl_channel_prefix_matrix,                                                                               \\\n                      recv_gbl_rank_prefix_sum,                                                                                \\\n                      is_token_in_rank,                                                                                        \\\n                      num_tokens,                                                                                              \\\n                      num_worst_tokens,                                                                                        \\\n                      hidden_int4,                                                                                             \\\n                      num_scales,                                                                                              \\\n                      num_topk,                                                                                                \\\n                      num_experts,                       
                                                                      \\\n                      scale_token_stride,                                                                                      \\\n                      scale_hidden_stride,                                                                                     \\\n                      rdma_buffer_ptr,                                                                                         \\\n                      num_max_rdma_chunked_send_tokens,                                                                        \\\n                      num_max_rdma_chunked_recv_tokens,                                                                        \\\n                      buffer_ptrs,                                                                                             \\\n                      num_max_nvl_chunked_send_tokens,                                                                         \\\n                      num_max_nvl_chunked_recv_tokens,                                                                         \\\n                      rank,                                                                                                    \\\n                      num_ranks);                                                                                              \\\n    }                                                                                                                          \\\n    break\n\n    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));\n    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));\n\n    SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream);\n    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);\n#undef DISPATCH_LAUNCH_CASE\n}\n\ntemplate <bool kLowLatencyMode, int kNumTMABytesPerWarp>\n__global__ void cached_notify(const int rdma_clean_offset,\n    
                          const int rdma_num_int_clean,\n                              const int nvl_clean_offset,\n                              const int nvl_num_int_clean,\n                              int* combined_rdma_head,\n                              int num_combined_tokens,\n                              int num_channels,\n                              const int* rdma_channel_prefix_matrix,\n                              const int* rdma_rank_prefix_sum,\n                              int* combined_nvl_head,\n                              void* rdma_buffer_ptr,\n                              void** buffer_ptrs,\n                              int** barrier_signal_ptrs,\n                              int rank,\n                              int num_ranks,\n                              bool is_cached_dispatch,\n                              const nvshmem_team_t rdma_team) {\n    auto sm_id = static_cast<int>(blockIdx.x);\n    auto thread_id = static_cast<int>(threadIdx.x);\n    auto num_threads = static_cast<int>(blockDim.x);\n    auto num_warps = num_threads / 32;\n    auto warp_id = thread_id / 32;\n    auto lane_id = get_lane_id();\n\n    auto nvl_rank = rank % NUM_MAX_NVL_PEERS;\n    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;\n    auto rdma_rank = rank / NUM_MAX_NVL_PEERS;\n\n    // Using two SMs, which clean the RDMA/NVL buffer respectively\n    if (sm_id == 0) {\n        auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;\n        for (int i = thread_id; i < qps_per_rdma_rank * (num_rdma_ranks - 1); i += num_threads) {\n            auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % num_rdma_ranks;\n            auto qp_id = i % qps_per_rdma_rank;\n            nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), qp_id);\n        }\n        __syncthreads();\n\n        // Barrier for RDMA\n        if (thread_id == 32)\n            
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);\n\n        // Barrier for NVL\n        barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);\n\n        // Clean RDMA buffer\n        auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);\n        #pragma unroll\n        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)\n            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;\n\n        // Clean NVL buffer\n        auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);\n        #pragma unroll\n        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)\n            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;\n        __syncthreads();\n\n        // Barrier again\n        if (thread_id == 32)\n            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);\n        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);\n    } else if (sm_id == 1) {\n        if (is_cached_dispatch)\n            return;\n\n        EP_DEVICE_ASSERT(num_warps >= num_channels);\n        EP_DEVICE_ASSERT(num_rdma_ranks <= 32);\n\n        // Iterate in reverse order\n        if (lane_id < num_rdma_ranks and warp_id < num_channels) {\n            int token_start_idx, token_end_idx;\n            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx);\n\n            // NOTES: `1 << 25` is a heuristic large number\n            int last_head = 1 << 25;\n            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {\n                auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);\n                if (current_head < 0) {\n                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;\n                } else {\n                    last_head = current_head;\n                }\n            }\n        }\n    } else {\n        if (is_cached_dispatch)\n       
     return;\n\n        EP_DEVICE_ASSERT(num_warps >= num_channels);\n        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);\n        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, \"Too many NVL peers\");\n\n        if (warp_id < num_channels) {\n            constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t);\n            constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS;\n            constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token;\n            EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0, \"num_bytes_per_token should be divisible by 16\");\n\n            // TMA stuffs\n            extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];\n            auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;\n            auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + tma_batch_size);\n            uint32_t tma_phase = 0;\n            if (elect_one_sync()) {\n                mbarrier_init(tma_mbarrier, 1);\n                fence_barrier_init();\n            }\n            __syncwarp();\n\n            for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) {\n                // Iterate in reverse order\n                int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];\n                int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];\n                int shift = dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];\n                token_start_idx += shift, token_end_idx += shift;\n\n                // NOTES: `1 << 25` is a heuristic large number\n                int last_head = 1 << 25;\n                for (int batch_end_idx = token_end_idx; batch_end_idx > token_start_idx; batch_end_idx -= num_tokens_per_batch) {\n                    auto batch_start_idx = max(token_start_idx, batch_end_idx - num_tokens_per_batch);\n\n                    if (elect_one_sync()) {\n                        tma_load_1d(tma_buffer,\n                                    combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS,\n                                    tma_mbarrier,\n                                    (batch_end_idx - batch_start_idx) * num_bytes_per_token);\n                        mbarrier_arrive_and_expect_tx(tma_mbarrier, (batch_end_idx - batch_start_idx) * num_bytes_per_token);\n                    }\n                    mbarrier_wait(tma_mbarrier, tma_phase);\n                    __syncwarp();\n\n                    for (int token_idx = batch_end_idx - 1; token_idx >= batch_start_idx; --token_idx) {\n                        if (lane_id < NUM_MAX_NVL_PEERS) {\n                            auto current_head =\n                                reinterpret_cast<int*>(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id];\n                            if (current_head < 0) {\n                                reinterpret_cast<int*>(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id] =\n                                    -last_head - 1;\n                            } else {\n                                last_head = current_head;\n                            }\n                        }\n                    }\n                    tma_store_fence();\n                    __syncwarp();\n\n                    if (elect_one_sync())\n                        tma_store_1d(tma_buffer,\n                             
        combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS,\n                                     (batch_end_idx - batch_start_idx) * num_bytes_per_token);\n                    tma_store_wait<0>();\n                    __syncwarp();\n                }\n            }\n        }\n    }\n}\n\nvoid cached_notify(int hidden_int4,\n                   int num_scales,\n                   int num_topk_idx,\n                   int num_topk_weights,\n                   int num_ranks,\n                   int num_channels,\n                   int num_combined_tokens,\n                   int* combined_rdma_head,\n                   const int* rdma_channel_prefix_matrix,\n                   const int* rdma_rank_prefix_sum,\n                   int* combined_nvl_head,\n                   void* rdma_buffer_ptr,\n                   int num_max_rdma_chunked_recv_tokens,\n                   void** buffer_ptrs,\n                   int num_max_nvl_chunked_recv_tokens,\n                   int** barrier_signal_ptrs,\n                   int rank,\n                   cudaStream_t stream,\n                   int64_t num_rdma_bytes,\n                   int64_t num_nvl_bytes,\n                   bool is_cached_dispatch,\n                   bool low_latency_mode) {\n    const int num_threads = std::max(128, 32 * num_channels);\n    const int num_warps = num_threads / 32;\n    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;\n    const int kNumTMABytesPerWarp = 8192;\n    const int smem_size = kNumTMABytesPerWarp * num_warps;\n\n    // Get clean meta\n    auto rdma_clean_meta = get_rdma_clean_meta(\n        hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);\n    auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4,\n                                             num_scales,\n                                             num_topk_idx,\n                                             num_topk_weights,\n                  
                           num_rdma_ranks,\n                                             NUM_MAX_NVL_PEERS,\n                                             num_max_nvl_chunked_recv_tokens,\n                                             num_channels,\n                                             is_cached_dispatch);\n    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);\n    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);\n    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());\n    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());\n    EP_HOST_ASSERT(num_channels * 2 > 3);\n\n    // Launch kernel\n    auto cached_notify_func = low_latency_mode ? cached_notify<true, kNumTMABytesPerWarp> : cached_notify<false, kNumTMABytesPerWarp>;\n    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);\n    SET_SHARED_MEMORY_FOR_TMA(cached_notify_func);\n    LAUNCH_KERNEL(&cfg,\n                  cached_notify_func,\n                  rdma_clean_meta.first,\n                  rdma_clean_meta.second,\n                  nvl_clean_meta.first,\n                  nvl_clean_meta.second,\n                  combined_rdma_head,\n                  num_combined_tokens,\n                  num_channels,\n                  rdma_channel_prefix_matrix,\n                  rdma_rank_prefix_sum,\n                  combined_nvl_head,\n                  rdma_buffer_ptr,\n                  buffer_ptrs,\n                  barrier_signal_ptrs,\n                  rank,\n                  num_ranks,\n                  is_cached_dispatch,\n                  cpu_rdma_team);\n}\n\ntemplate <int kNumRanks,\n          bool kMaybeWithBias,\n          typename dtype_t,\n          int kMaxNumRanks,\n          bool kUseTMA,\n          int kNumStages,\n          int kNumTMALoadBytes = 0,\n          typename GetAddrFn,\n          typename ReceiveTWFn>\n__device__ int combine_token(bool 
is_token_in_rank,\n                             int head_idx,\n                             int lane_id,\n                             int hidden_int4,\n                             int num_topk,\n                             int4* combined_row,\n                             float* combined_topk_weights,\n                             const int4* bias_0_int4,\n                             const int4* bias_1_int4,\n                             int num_max_recv_tokens,\n                             const GetAddrFn& get_addr_fn,\n                             const ReceiveTWFn& recv_tw_fn,\n                             uint8_t* smem_ptr,\n                             uint32_t (&tma_phase)[kNumStages]) {\n    constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);\n\n    // Broadcast current heads\n    // Lane `i` holds the head of rank `i` and `is_token_in_rank`\n    EP_STATIC_ASSERT(kMaxNumRanks <= 32, \"Too many ranks\");\n    int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];\n    #pragma unroll\n    for (int i = 0; i < kNumRanks; ++i)\n        if (__shfl_sync(0xffffffff, is_token_in_rank, i)) {\n            slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens;\n            topk_ranks[num_topk_ranks++] = i;\n        }\n    EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);\n    EP_STATIC_ASSERT(not(kUseTMA and kMaybeWithBias), \"TMA cannot be used by receiver warps\");\n    EP_STATIC_ASSERT(kNumStages == 2, \"Only support 2 stages now\");\n\n    // Reduce data\n    if constexpr (kUseTMA) {\n        constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16;\n        EP_DEVICE_ASSERT(hidden_int4 % 32 == 0);\n\n        auto tma_load_buffer = [=](const int& i, const int& j) -> int4* {\n            return reinterpret_cast<int4*>(smem_ptr + i * kNumTMABufferBytesPerStage + j * kNumTMALoadBytes);\n        };\n        auto tma_store_buffer = [=](const int& i) -> int4* {\n   
         return reinterpret_cast<int4*>(smem_ptr + i * kNumTMABufferBytesPerStage + NUM_MAX_NVL_PEERS * kNumTMALoadBytes);\n        };\n        auto tma_mbarrier = [=](const int& i) -> uint64_t* {\n            return reinterpret_cast<uint64_t*>(smem_ptr + i * kNumTMABufferBytesPerStage + (NUM_MAX_NVL_PEERS + 1) * kNumTMALoadBytes);\n        };\n\n        // Prefetch\n        if (lane_id < num_topk_ranks)\n            tma_load_1d(\n                tma_load_buffer(0, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], 0), tma_mbarrier(0), kNumTMALoadBytes);\n        mbarrier_arrive_and_expect_tx(tma_mbarrier(0), lane_id < num_topk_ranks ? kNumTMALoadBytes : 0);\n        __syncwarp();\n\n        for (int shifted = 0, iter = 0; shifted < hidden_int4; shifted += 32, iter += 1) {\n            const int stage_idx = iter % kNumStages;\n            const int next_stage_idx = (iter + 1) % kNumStages;\n\n            // Prefetch next stage\n            if (shifted + 32 < hidden_int4) {\n                if (lane_id < num_topk_ranks)\n                    tma_load_1d(tma_load_buffer(next_stage_idx, lane_id),\n                                get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], shifted + 32),\n                                tma_mbarrier(next_stage_idx),\n                                kNumTMALoadBytes);\n                mbarrier_arrive_and_expect_tx(tma_mbarrier(next_stage_idx), lane_id < num_topk_ranks ? 
kNumTMALoadBytes : 0);\n                __syncwarp();\n            }\n\n            mbarrier_wait(tma_mbarrier(stage_idx), tma_phase[stage_idx]);\n            float values[kDtypePerInt4] = {0};\n            #pragma unroll\n            for (int j = 0; j < num_topk_ranks; ++j) {\n                auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(tma_load_buffer(stage_idx, j) + lane_id);\n                #pragma unroll\n                for (int k = 0; k < kDtypePerInt4; ++k)\n                    values[k] += static_cast<float>(recv_value_dtypes[k]);\n            }\n\n            // Wait for shared memory to be released\n            tma_store_wait<kNumStages - 1>();\n\n            // Copy into shared memory and issue the TMA store\n            auto out_dtypes = reinterpret_cast<dtype_t*>(tma_store_buffer(stage_idx) + lane_id);\n            #pragma unroll\n            for (int j = 0; j < kDtypePerInt4; ++j)\n                out_dtypes[j] = static_cast<dtype_t>(values[j]);\n            tma_store_fence();\n            __syncwarp();\n\n            if (elect_one_sync())\n                tma_store_1d(tma_store_buffer(stage_idx), combined_row + shifted, kNumTMALoadBytes);\n            __syncwarp();\n        }\n\n        // Flush all writes\n        tma_store_wait<0>();\n    } else {\n        #pragma unroll\n        for (int i = lane_id; i < hidden_int4; i += 32) {\n            // Read bias\n            // TODO: make it a finer-grained template\n            int4 bias_0_value_int4, bias_1_value_int4;\n            if constexpr (kMaybeWithBias) {\n                bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0);\n                bias_1_value_int4 = bias_1_int4 != nullptr ? 
ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0);\n            }\n\n            // Read buffers\n            // TODO: maybe too many registers here\n            int4 recv_value_int4[kMaxNumRanks];\n            #pragma unroll\n            for (int j = 0; j < num_topk_ranks; ++j)\n                recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));\n\n            // Clean\n            // Reduce bias\n            float values[kDtypePerInt4] = {0};\n            if constexpr (kMaybeWithBias) {\n                auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);\n                auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);\n                #pragma unroll\n                for (int j = 0; j < kDtypePerInt4; ++j)\n                    values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);\n            }\n\n            // Reduce all-to-all results\n            #pragma unroll\n            for (int j = 0; j < num_topk_ranks; ++j) {\n                auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);\n                #pragma unroll\n                for (int k = 0; k < kDtypePerInt4; ++k)\n                    values[k] += static_cast<float>(recv_value_dtypes[k]);\n            }\n\n            // Cast back to `dtype_t` and write\n            int4 out_int4;\n            auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);\n            #pragma unroll\n            for (int j = 0; j < kDtypePerInt4; ++j)\n                out_dtypes[j] = static_cast<dtype_t>(values[j]);\n            st_na_global(combined_row + i, out_int4);\n        }\n    }\n\n    // Reduce `topk_weights`\n    if (lane_id < num_topk) {\n        float value = 0;\n        #pragma unroll\n        for (int i = 0; i < num_topk_ranks; ++i)\n            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);\n        st_na_global(combined_topk_weights + lane_id, 
value);\n    }\n\n    // Return the minimum top-k rank\n    return topk_ranks[0];\n}\n\ntemplate <bool kLowLatencyMode,\n          int kNumRDMARanks,\n          typename dtype_t,\n          int kNumCombineForwarderWarps,\n          int kNumTMABytesPerSenderWarp,\n          int kNumTMABytesPerForwarderWarp,\n          int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),\n          int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,\n          int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder,\n          int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS>\n__global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* combined_x,\n                                                                        float* combined_topk_weights,\n                                                                        const bool* is_combined_token_in_rank,\n                                                                        const int4* x,\n                                                                        const float* topk_weights,\n                                                                        const int4* bias_0,\n                                                                        const int4* bias_1,\n                                                                        const int* combined_rdma_head,\n                                                                        const int* combined_nvl_head,\n                                                                        const SourceMeta* src_meta,\n                                                                        const int* rdma_channel_prefix_matrix,\n                                                                        const int* rdma_rank_prefix_sum,\n                                                                        const int* gbl_channel_prefix_matrix,\n                      
                                                  int num_tokens,\n                                                                        int num_combined_tokens,\n                                                                        int hidden,\n                                                                        int num_topk,\n                                                                        void* rdma_buffer_ptr,\n                                                                        int num_max_rdma_chunked_send_tokens,\n                                                                        int num_max_rdma_chunked_recv_tokens,\n                                                                        void** buffer_ptrs,\n                                                                        int num_max_nvl_chunked_send_tokens,\n                                                                        int num_max_nvl_chunked_recv_tokens,\n                                                                        int rank,\n                                                                        int num_ranks) {\n    enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kCoordinator };\n\n    const auto sm_id = static_cast<int>(blockIdx.x);\n    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;\n    const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();\n    const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;\n    const bool is_forwarder_sm = sm_id % 2 == 1;\n\n    EP_DEVICE_ASSERT(num_topk <= 32);\n    EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0);\n    const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));\n    const auto hidden_bytes = hidden_int4 * sizeof(int4);\n    const auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, 0, 0, num_topk);\n\n    // NOTES: we decouple a channel into 2 SMs\n    const 
auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;\n    auto role_meta = [=]() -> std::pair<WarpRole, int> {\n        auto warp_id = thread_id / 32;\n        if (not is_forwarder_sm) {\n            if (warp_id < NUM_MAX_NVL_PEERS) {\n                auto shuffled_warp_id = warp_id;\n                shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS;\n                return {WarpRole::kNVLSender, shuffled_warp_id};\n            } else if (warp_id < kNumForwarders) {\n                return {WarpRole::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS};\n            } else {\n                return {WarpRole::kCoordinator, 0};\n            }\n        } else {\n            if (warp_id < kNumForwarders) {\n                auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders;\n                return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};\n            } else {\n                return {WarpRole::kCoordinator, 0};\n            }\n        }\n    }();\n    auto warp_role = role_meta.first;\n    auto warp_id = role_meta.second;\n\n    EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1);\n    auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;\n\n    if (warp_role == WarpRole::kNVLSender) {\n        // NVL producers\n        const auto dst_nvl_rank = warp_id;\n\n        // NVL layouts\n        // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources\n        auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];\n        auto nvl_channel_x = AsymBuffer<uint8_t>(dst_buffer_ptr,\n                                                 num_max_nvl_chunked_recv_tokens * num_bytes_per_token,\n                                                 NUM_MAX_NVL_PEERS,\n                                                 channel_id,\n                                                 num_channels,\n                                     
            nvl_rank)\n                                 .advance_also(local_buffer_ptr);\n        auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank)\n                                    .advance_also(dst_buffer_ptr);\n        auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)\n                                    .advance_also(local_buffer_ptr);\n\n        // TMA stuffs\n        extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];\n        auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerSenderWarp;\n        auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + num_bytes_per_token);\n        uint32_t tma_phase = 0;\n        if (elect_one_sync()) {\n            mbarrier_init(tma_mbarrier, 1);\n            fence_barrier_init();\n            EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerSenderWarp);\n        }\n        __syncwarp();\n\n        // Get tasks for each RDMA lane\n        int token_start_idx = 0, token_end_idx = 0;\n        if (lane_id < kNumRDMARanks) {\n            int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;\n            token_start_idx = gbl_channel_prefix_matrix[prefix_idx];\n            token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? 
num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];\n        }\n        __syncwarp();\n\n        // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer\n        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;\n        EP_STATIC_ASSERT(kNumRDMARanks <= 32, \"Invalid number of RDMA peers\");\n\n        // Iterate over all tokens and send by chunks\n        int current_rdma_idx = channel_id % kNumRDMARanks;\n        while (true) {\n            // Exit if possible\n            if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))\n                break;\n\n            // Decide the next RDMA buffer to send\n            bool is_lane_ready = false;\n            auto start_time = clock64();\n            while (true) {\n                int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;\n                is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and\n                    num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;\n                if (__any_sync(0xffffffff, is_lane_ready))\n                    break;\n\n                // Retry\n                if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx)\n                    cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);\n\n                // Timeout check\n                if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {\n                    printf(\n                        \"DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: \"\n                        \"%d, start: %d, end: %d\\n\",\n                        channel_id,\n                        rdma_rank,\n                        nvl_rank,\n                        dst_nvl_rank,\n                        lane_id,\n                        ld_volatile_global(nvl_channel_head.buffer() + 
lane_id),\n                        cached_channel_tail_idx,\n                        token_start_idx,\n                        token_end_idx);\n                    trap();\n                }\n            }\n\n            // Sync token start index and count\n            for (int i = 0; i < kNumRDMARanks; ++i) {\n                current_rdma_idx = (current_rdma_idx + 1) % kNumRDMARanks;\n                if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))\n                    continue;\n\n                // Sync token start index\n                auto token_idx = static_cast<int64_t>(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx));\n                int num_tokens_in_chunk =\n                    __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);\n\n                // Send by chunk\n                for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {\n                    // Get an empty slot\n                    int dst_slot_idx = 0;\n                    if (lane_id == current_rdma_idx) {\n                        dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;\n                        dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;\n                    }\n                    dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx);\n\n                    // Load data\n                    auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;\n                    auto shifted_x = x + token_idx * hidden_int4;\n                    tma_store_wait<0>();\n                    if (elect_one_sync()) {\n                        tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes);\n                        mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes);\n                    }\n  
                  __syncwarp();\n                    mbarrier_wait(tma_mbarrier, tma_phase);\n\n                    // Load source meta\n                    if (lane_id == num_topk)\n                        *reinterpret_cast<SourceMeta*>(tma_buffer + hidden_bytes) = ld_nc_global(src_meta + token_idx);\n\n                    // Load `topk_weights`\n                    if (lane_id < num_topk)\n                        *reinterpret_cast<float*>(tma_buffer + hidden_bytes + sizeof(SourceMeta) + lane_id * sizeof(float)) =\n                            ld_nc_global(topk_weights + token_idx * num_topk + lane_id);\n\n                    // Issue TMA store\n                    tma_store_fence();\n                    __syncwarp();\n                    if (elect_one_sync())\n                        tma_store_1d(tma_buffer, shifted_x_buffers, num_bytes_per_token, false);\n                }\n                lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;\n            }\n\n            // Move queue tail\n            tma_store_wait<0>();\n            __syncwarp();\n            if (lane_id < kNumRDMARanks and is_lane_ready)\n                st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);\n        }\n    } else {\n        // Combiners and coordinators\n        // RDMA symmetric layout\n        auto rdma_channel_data = SymBuffer<int8_t>(\n            rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels);\n        auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);\n        auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);\n\n        // NVL layouts\n        void* local_nvl_buffer = buffer_ptrs[nvl_rank];\n        void* nvl_buffers[NUM_MAX_NVL_PEERS];\n        #pragma unroll\n        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)\n       
     nvl_buffers[i] = buffer_ptrs[i];\n        auto nvl_channel_x =\n            AsymBuffer<uint8_t>(\n                local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels)\n                .advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);\n        auto nvl_channel_head =\n            AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)\n                .advance_also(local_nvl_buffer);\n        auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels)\n                                    .advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);\n\n        // Combiner warp synchronization\n        __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];\n        __shared__ volatile bool forwarder_retired[kNumForwarders];\n        __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];\n        __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];\n        auto sync_forwarder_smem = [=]() { asm volatile(\"barrier.sync 0, %0;\" ::\"r\"((kNumForwarders + 1) * 32)); };\n        auto sync_rdma_receiver_smem = [=]() { asm volatile(\"barrier.sync 1, %0;\" ::\"r\"((kNumRDMAReceivers + 1) * 32)); };\n\n        if (warp_role == WarpRole::kNVLAndRDMAForwarder) {\n            // Receive from NVL ranks and forward to RDMA ranks\n            // NOTES: this part uses \"large warps\" for each RDMA rank\n            const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder;\n            const auto sub_warp_id = warp_id % kNumWarpsPerForwarder;\n            auto send_buffer =\n                dst_rdma_rank == rdma_rank ? 
rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);\n            auto sync_large_warp = [=]() {\n                if (kNumWarpsPerForwarder == 1) {\n                    __syncwarp();\n                } else {\n                    asm volatile(\"bar.sync %0, %1;\" ::\"r\"(dst_rdma_rank + 2), \"r\"(kNumWarpsPerForwarder * 32));\n                }\n            };\n            EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, \"Not enough barriers\");\n\n            // TMA stuffs\n            constexpr int kNumStages = 2;\n            constexpr int kNumTMALoadBytes = sizeof(int4) * 32;\n            constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16;\n            EP_STATIC_ASSERT(kNumTMABufferBytesPerStage * kNumStages <= kNumTMABytesPerForwarderWarp, \"TMA buffer is not large enough\");\n\n            extern __shared__ __align__(1024) uint8_t smem_buffer[];\n            auto smem_ptr = smem_buffer + warp_id * kNumStages * kNumTMABufferBytesPerStage;\n            auto tma_mbarrier = [=](const int& i) {\n                return reinterpret_cast<uint64_t*>(smem_ptr + i * kNumTMABufferBytesPerStage + kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1));\n            };\n            uint32_t tma_phase[kNumStages] = {0};\n            if (lane_id < kNumStages) {\n                mbarrier_init(tma_mbarrier(lane_id), 32);\n                fence_barrier_init();\n            }\n            __syncwarp();\n\n            // Advance to the corresponding NVL buffer\n            nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_bytes_per_token);\n            nvl_channel_head.advance(dst_rdma_rank);\n            nvl_channel_tail.advance(dst_rdma_rank);\n\n            // Clean shared memory and sync\n            EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, \"Invalid number of NVL peers\");\n            lane_id < NUM_MAX_NVL_PEERS ? 
(forwarder_nvl_head[warp_id][lane_id] = 0) : 0;\n            lane_id == 0 ? (forwarder_retired[warp_id] = false) : false;\n            sync_forwarder_smem();\n\n            // Get count and cached head\n            int cached_nvl_channel_tail_idx = 0;\n            int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];\n            int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];\n            num_tokens_to_combine -= num_tokens_prefix;\n            num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];\n            combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;\n\n            // Iterate over all tokens and combine by chunks\n            for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {\n                // Check destination queue emptiness, or wait a buffer to be released\n                auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);\n                auto num_chunked_tokens = token_end_idx - token_start_idx;\n                auto start_time = clock64();\n                while (sub_warp_id == 0 and lane_id == 0) {\n                    // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`\n                    // Here, `token_start_idx` is the actual tail\n                    int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));\n                    if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)\n                        break;\n\n                    // Timeout check\n                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                        printf(\n                            \"DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: 
%d, dst RDMA: %d, head: %ld, tail: \"\n                            \"%d, chunked: %d\\n\",\n                            channel_id,\n                            rdma_rank,\n                            nvl_rank,\n                            dst_rdma_rank,\n                            ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)),\n                            token_start_idx,\n                            num_chunked_tokens);\n                        trap();\n                    }\n                }\n                sync_large_warp();\n\n                // Combine and write to the RDMA buffer\n                for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {\n                    // Read expected head\n                    EP_STATIC_ASSERT(kNumRDMARanks <= 32, \"Invalid number of RDMA peers\");\n                    int expected_head = -1;\n                    if (lane_id < NUM_MAX_NVL_PEERS) {\n                        expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);\n                        expected_head < 0 ? 
(forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1)\n                                          : (forwarder_nvl_head[warp_id][lane_id] = expected_head);\n                    }\n\n                    // Wait lanes to be ready\n                    start_time = clock64();\n                    while (cached_nvl_channel_tail_idx <= expected_head) {\n                        cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));\n\n                        // Timeout check\n                        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {\n                            printf(\n                                \"DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, \"\n                                \"tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\\n\",\n                                channel_id,\n                                rdma_rank,\n                                nvl_rank,\n                                lane_id,\n                                dst_rdma_rank,\n                                cached_nvl_channel_tail_idx,\n                                token_idx,\n                                num_tokens_to_combine,\n                                sub_warp_id,\n                                kNumWarpsPerForwarder,\n                                expected_head);\n                            trap();\n                        }\n                    }\n\n                    // Combine current token\n                    auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;\n                    void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token;\n                    auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* {\n                        return reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) +\n            
                hidden_int4_idx;\n                    };\n                    auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float {\n                        return ld_nc_global(reinterpret_cast<float*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token +\n                                                                     hidden_bytes + sizeof(SourceMeta)) +\n                                            topk_idx);\n                    };\n                    combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS, true, kNumStages, kNumTMALoadBytes>(\n                        expected_head >= 0,\n                        expected_head,\n                        lane_id,\n                        hidden_int4,\n                        num_topk,\n                        static_cast<int4*>(shifted),\n                        reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),\n                        nullptr,\n                        nullptr,\n                        num_max_nvl_chunked_recv_tokens_per_rdma,\n                        get_addr_fn,\n                        recv_tw_fn,\n                        smem_ptr,\n                        tma_phase);\n\n                    // Update head\n                    if (lane_id < NUM_MAX_NVL_PEERS)\n                        expected_head < 0 ? 
(forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1)\n                                          : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1);\n                }\n                sync_large_warp();\n\n                // Issue RDMA send\n                if (sub_warp_id == kNumWarpsPerForwarder - 1) {\n                    if (dst_rdma_rank != rdma_rank) {\n                        auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;\n                        const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token;\n                        const auto dst_ptr =\n                            reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token);\n                        const auto src_ptr =\n                            reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token);\n                        nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr,\n                                                          src_ptr,\n                                                          num_bytes_per_msg,\n                                                          translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),\n                                                          channel_id,\n                                                          lane_id,\n                                                          0);\n                    } else {\n                        memory_fence();\n                    }\n\n                    // Write new RDMA tail\n                    __syncwarp();\n                    if (elect_one_sync()) {\n                        nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank),\n                                                        num_chunked_tokens,\n                                                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, 
nvl_rank),\n                                                        channel_id,\n                                                        dst_rdma_rank == rdma_rank);\n                    }\n                }\n            }\n\n            // Retired\n            __syncwarp();\n            if (elect_one_sync())\n                forwarder_retired[warp_id] = true;\n        } else if (warp_role == WarpRole::kRDMAReceiver) {\n            // Receive from RDMA ranks and write to the output tensor\n            // Clean shared memory and sync\n            EP_DEVICE_ASSERT(kNumRDMARanks <= 32);\n            lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0;\n            lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0;\n            sync_rdma_receiver_smem();\n\n            // The same tokens as the dispatch process\n            int token_start_idx, token_end_idx;\n            get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);\n\n            // Iterate over all tokens and combine\n            int cached_channel_tail_idx = 0;\n            for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {\n                // Read expected head\n                EP_STATIC_ASSERT(kNumRDMARanks <= 32, \"Invalid number of RDMA peers\");\n                int expected_head = -1;\n                if (lane_id < kNumRDMARanks) {\n                    expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);\n                    (expected_head < 0) ? 
(rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1)\n                                        : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head);\n                }\n\n                // Wait lanes to be ready\n                auto start_time = clock64();\n                while (cached_channel_tail_idx <= expected_head) {\n                    cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));\n\n                    // Timeout check\n                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                        printf(\n                            \"DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, \"\n                            \"expect: %d\\n\",\n                            channel_id,\n                            rdma_rank,\n                            nvl_rank,\n                            lane_id,\n                            cached_channel_tail_idx,\n                            token_idx,\n                            expected_head);\n                        trap();\n                    }\n                }\n                __syncwarp();\n\n                // Combine current token\n                auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* {\n                    return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) +\n                        hidden_int4_idx;\n                };\n                auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float {\n                    return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) +\n                                                                       slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) +\n                                        topk_idx);\n                };\n                
uint32_t dummy_tma_phases[2];\n                combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks, false, 2>(\n                    expected_head >= 0,\n                    expected_head,\n                    lane_id,\n                    hidden_int4,\n                    num_topk,\n                    combined_x + token_idx * hidden_int4,\n                    combined_topk_weights + token_idx * num_topk,\n                    bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4,\n                    bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4,\n                    num_max_rdma_chunked_recv_tokens,\n                    get_addr_fn,\n                    recv_tw_fn,\n                    nullptr,\n                    dummy_tma_phases);\n            }\n\n            // Retired\n            __syncwarp();\n            if (elect_one_sync())\n                rdma_receiver_retired[warp_id] = true;\n        } else {\n            // Coordinator\n            // Sync shared memory status\n            is_forwarder_sm ? sync_forwarder_smem() : sync_rdma_receiver_smem();\n            const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;\n\n            int last_rdma_head = 0;\n            int last_nvl_head[kNumRDMARanks] = {0};\n            int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;\n            int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? 
lane_id : 0;\n            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, \"Invalid number of forwarder warps\");\n            while (true) {\n                // Retired\n                if (not is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))\n                    break;\n                if (is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id]))\n                    break;\n\n                // Find minimum head for RDMA ranks\n                if (not is_forwarder_sm) {\n                    int min_head = std::numeric_limits<int>::max();\n                    #pragma unroll\n                    for (int i = 0; i < kNumRDMAReceivers; ++i)\n                        if (not rdma_receiver_retired[i])\n                            min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);\n                    if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and\n                        lane_id < kNumRDMARanks) {\n                        nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank),\n                                                        min_head - last_rdma_head,\n                                                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),\n                                                        channel_id + num_channels,\n                                                        dst_rdma_rank == rdma_rank);\n                        last_rdma_head = min_head;\n                    }\n                } else {\n                    // Find minimum head for NVL ranks\n                    #pragma unroll\n                    for (int i = 0; i < kNumRDMARanks; ++i) {\n                        int min_head = std::numeric_limits<int>::max();\n                        #pragma unroll\n                        for (int j = 0; j < 
num_warps_per_rdma_rank; ++j)\n                            if (not forwarder_retired[i * num_warps_per_rdma_rank + j])\n                                min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);\n                        if (min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS)\n                            st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);\n                    }\n                }\n\n                // Nanosleep and let other warps work\n                __nanosleep(NUM_WAIT_NANOSECONDS);\n            }\n        }\n    }\n}\n\nvoid combine(cudaDataType_t type,\n             void* combined_x,\n             float* combined_topk_weights,\n             const bool* is_combined_token_in_rank,\n             const void* x,\n             const float* topk_weights,\n             const void* bias_0,\n             const void* bias_1,\n             const int* combined_rdma_head,\n             const int* combined_nvl_head,\n             const void* src_meta,\n             const int* rdma_channel_prefix_matrix,\n             const int* rdma_rank_prefix_sum,\n             const int* gbl_channel_prefix_matrix,\n             int num_tokens,\n             int num_combined_tokens,\n             int hidden,\n             int num_topk,\n             void* rdma_buffer_ptr,\n             int num_max_rdma_chunked_send_tokens,\n             int num_max_rdma_chunked_recv_tokens,\n             void** buffer_ptrs,\n             int num_max_nvl_chunked_send_tokens,\n             int num_max_nvl_chunked_recv_tokens,\n             int rank,\n             int num_ranks,\n             cudaStream_t stream,\n             int num_channels,\n             bool low_latency_mode) {\n    constexpr int kNumCombineForwarderWarps = 24;\n    constexpr int kNumTMABytesPerSenderWarp = 16384;\n    constexpr int kNumTMABytesPerForwarderWarp = 9248;\n 
   constexpr int smem_size =\n        std::max(kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS, kNumTMABytesPerForwarderWarp * kNumCombineForwarderWarps);\n\n#define COMBINE_LAUNCH_CASE(num_rdma_ranks)                                           \\\n    {                                                                                 \\\n        auto combine_func = low_latency_mode ? combine<true,                          \\\n                                                       num_rdma_ranks,                \\\n                                                       nv_bfloat16,                   \\\n                                                       kNumCombineForwarderWarps,     \\\n                                                       kNumTMABytesPerSenderWarp,     \\\n                                                       kNumTMABytesPerForwarderWarp>  \\\n                                             : combine<false,                         \\\n                                                       num_rdma_ranks,                \\\n                                                       nv_bfloat16,                   \\\n                                                       kNumCombineForwarderWarps,     \\\n                                                       kNumTMABytesPerSenderWarp,     \\\n                                                       kNumTMABytesPerForwarderWarp>; \\\n        SET_SHARED_MEMORY_FOR_TMA(combine_func);                                      \\\n        LAUNCH_KERNEL(&cfg,                                                           \\\n                      combine_func,                                                   \\\n                      reinterpret_cast<int4*>(combined_x),                            \\\n                      combined_topk_weights,                                          \\\n                      is_combined_token_in_rank,                                      \\\n                      
reinterpret_cast<const int4*>(x),                               \\\n                      topk_weights,                                                   \\\n                      reinterpret_cast<const int4*>(bias_0),                          \\\n                      reinterpret_cast<const int4*>(bias_1),                          \\\n                      combined_rdma_head,                                             \\\n                      combined_nvl_head,                                              \\\n                      reinterpret_cast<const SourceMeta*>(src_meta),                  \\\n                      rdma_channel_prefix_matrix,                                     \\\n                      rdma_rank_prefix_sum,                                           \\\n                      gbl_channel_prefix_matrix,                                      \\\n                      num_tokens,                                                     \\\n                      num_combined_tokens,                                            \\\n                      hidden,                                                         \\\n                      num_topk,                                                       \\\n                      rdma_buffer_ptr,                                                \\\n                      num_max_rdma_chunked_send_tokens,                               \\\n                      num_max_rdma_chunked_recv_tokens,                               \\\n                      buffer_ptrs,                                                    \\\n                      num_max_nvl_chunked_send_tokens,                                \\\n                      num_max_nvl_chunked_recv_tokens,                                \\\n                      rank,                                                           \\\n                      num_ranks);                                                     \\\n    }                                     
                                            \\\n    break\n\n    int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;\n    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);\n    int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;\n    EP_HOST_ASSERT(num_rdma_ranks <= kNumCombineForwarderWarps);\n    EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0);\n    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);\n    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks >\n                   std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));\n    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks - num_warps_per_forwarder >= num_max_nvl_chunked_send_tokens);\n    EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder);\n    EP_HOST_ASSERT(type == CUDA_R_16BF);\n\n    SETUP_LAUNCH_CONFIG(num_channels * 2, (num_forwarder_warps + 1) * 32, stream);\n    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);\n#undef COMBINE_LAUNCH_CASE\n}\n\n}  // namespace internode\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/internode_ll.cu",
    "content": "#include \"configs.cuh\"\n#include \"exception.cuh\"\n#include \"ibgda_device.cuh\"\n#include \"launch.cuh\"\n\nnamespace deep_ep {\n\nnamespace internode_ll {\n\ntemplate <bool use_warp_sync = false>\n__forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) {\n    if (mask_buffer_ptr == nullptr) {\n        return false;\n    }\n    if constexpr (use_warp_sync) {\n        return __shfl_sync(0xffffffff, ld_acquire_global(mask_buffer_ptr + rank), 0) != 0;\n    } else {\n        return ld_acquire_global(mask_buffer_ptr + rank) != 0;\n    }\n}\n\ntemplate <int kNumThreads>\n__forceinline__ __device__ void barrier(int thread_id, int rank, int num_ranks, int* mask_buffer_ptr, int* sync_buffer_ptr) {\n    EP_DEVICE_ASSERT(kNumThreads >= num_ranks);\n\n    // Quiet all QPs\n    auto qps_per_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;\n\n    for (int i = thread_id; i < qps_per_rank * (num_ranks - 1); i += kNumThreads) {\n        auto dst_rank = (rank + 1 + i / qps_per_rank) % num_ranks;\n        auto qp_id = i % qps_per_rank;\n        nvshmemi_ibgda_quiet(dst_rank, qp_id);\n    }\n\n    // Update local counter\n    if (thread_id == 0)\n        atomicAdd(sync_buffer_ptr + rank, -1);\n    __syncthreads();\n\n    int cnt = sync_buffer_ptr[rank];\n    // Update remote counter and wait for local counter to be updated\n    if (thread_id < num_ranks && thread_id != rank) {\n        const auto dst_rank = thread_id;\n        const auto dst_ptr = reinterpret_cast<uint64_t>(sync_buffer_ptr + rank);\n        const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);\n\n        if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {\n            if (dst_p2p_ptr == 0) {\n                nvshmemi_ibgda_rma_p(reinterpret_cast<int*>(dst_ptr), cnt, dst_rank, 0);\n            } else {\n                st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), cnt);\n            }\n\n            auto 
start_time = clock64();\n            uint64_t wait_recv_cost = 0;\n            while (ld_acquire_sys_global(sync_buffer_ptr + dst_rank) != cnt            // remote is not ready\n                   && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES  // not timeout\n            )\n                ;\n            // Mask rank if timeout\n            if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {\n                printf(\"Warning: DeepEP timeout for barrier, rank %d, dst_rank %d\\n\", rank, dst_rank);\n                if (mask_buffer_ptr == nullptr)\n                    trap();\n                atomicExch(mask_buffer_ptr + dst_rank, 1);\n            }\n        }\n    }\n    __syncthreads();\n}\n\ntemplate <int kNumThreads>\n__launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(int* clean_0,\n                                                                           int num_clean_int_0,\n                                                                           int* clean_1,\n                                                                           int num_clean_int_1,\n                                                                           int rank,\n                                                                           int num_ranks,\n                                                                           int* mask_buffer_ptr,\n                                                                           int* sync_buffer_ptr) {\n    auto thread_id = static_cast<int>(threadIdx.x);\n\n    // Barrier before cleaning (in case of unfinished chunked EP)\n    if (sync_buffer_ptr == nullptr)\n        nvshmemx_barrier_all_block();\n    else\n        barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);\n\n    // Clean\n    #pragma unroll\n    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)\n        clean_0[i] = 0;\n    #pragma unroll\n    for (int i = thread_id; i < num_clean_int_1; i += 
kNumThreads)\n        clean_1[i] = 0;\n\n    // Barrier after cleaning (make sure the low-latency mode works fine)\n    if (sync_buffer_ptr == nullptr)\n        nvshmemx_barrier_all_block();\n    else\n        barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);\n}\n\nvoid clean_low_latency_buffer(int* clean_0,\n                              int num_clean_int_0,\n                              int* clean_1,\n                              int num_clean_int_1,\n                              int rank,\n                              int num_ranks,\n                              int* mask_buffer_ptr,\n                              int* sync_buffer_ptr,\n                              cudaStream_t stream) {\n    constexpr int kNumThreads = 256;\n\n    SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);\n\n    LAUNCH_KERNEL(&cfg,\n                  clean_low_latency_buffer<kNumThreads>,\n                  clean_0,\n                  num_clean_int_0,\n                  clean_1,\n                  num_clean_int_1,\n                  rank,\n                  num_ranks,\n                  mask_buffer_ptr,\n                  sync_buffer_ptr);\n}\n\ntemplate <bool kUseFP8, bool kUseUE8M0, int kHidden>\n__global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,\n                                                    void* packed_recv_x_scales,\n                                                    int* packed_recv_src_info,\n                                                    int64_t* packed_recv_layout_range,\n                                                    int* packed_recv_count,\n                                                    int* mask_buffer_ptr,\n                                                    int* cumulative_local_expert_recv_stats,\n                                                    int64_t* dispatch_wait_recv_cost_stats,\n                                                    void* rdma_recv_x,\n                                   
                 int* rdma_recv_count,\n                                                    void* rdma_x,\n                                                    const void* x,\n                                                    const topk_idx_t* topk_idx,\n                                                    int* atomic_counter_per_expert,\n                                                    int* atomic_finish_counter_per_expert,\n                                                    int* next_clean,\n                                                    int num_next_clean_int,\n                                                    int num_tokens,\n                                                    int num_max_dispatch_tokens_per_rank,\n                                                    int num_topk,\n                                                    int num_experts,\n                                                    int rank,\n                                                    int num_ranks,\n                                                    int num_warp_groups,\n                                                    int num_warps_per_group,\n                                                    bool round_scale,\n                                                    int phases) {\n    const auto sm_id = static_cast<int>(blockIdx.x);\n    const auto thread_id = static_cast<int>(threadIdx.x);\n    const auto warp_id = thread_id / 32, lane_id = get_lane_id();\n    const auto num_sms = static_cast<int>(gridDim.x);\n    const auto num_warps = num_warp_groups * num_warps_per_group;\n    const auto num_local_experts = num_experts / num_ranks;\n    const auto warp_group_id = warp_id / num_warps_per_group;\n    const auto sub_warp_id = warp_id % num_warps_per_group;\n    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;\n\n    // May extract UE8M0 from the scales\n    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;\n    using packed_t = 
std::conditional_t<kUseUE8M0, uint32_t, float>;\n    EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, \"Invalid vector length\");\n\n    // FP8 stuff\n    constexpr int kNumPerChannels = 128;\n    const int num_scales = kHidden / kNumPerChannels;\n    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));\n    const size_t hidden_int4 = hidden_bytes / sizeof(int4);\n\n    // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales\n    // NOTES: currently we have 3 reserved int fields for future use\n    using vec_t = std::conditional_t<kUseFP8, int2, int4>;\n    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));\n    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);\n    EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);\n\n    // Expert counts\n    constexpr int kNumMaxWarpGroups = 32;\n    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];\n\n    // Sending phase\n    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)\n        goto LOW_LATENCY_DISPATCH_RECV;\n\n    // There are 2 kinds of warps in this part:\n    // 1. The first-kind warps for FP8 cast and sending top-k tokens\n    // 2. 
The last warp for reading `topk_idx` and count for per-expert information\n    if (warp_id < num_warps - 1) {\n        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);\n        EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, \"Invalid hidden\");\n        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, \"Invalid vectorization\");\n        const auto num_threads = (num_warps - 1) * 32;\n        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;\n\n        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {\n            const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;\n            const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);\n            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));\n            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);\n\n            // Overlap top-k index read and source token index writes\n            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;\n            thread_id == 0 ? 
(*rdma_x_src_idx = token_idx) : 0;\n\n            // FP8 cast\n            EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, \"Must use the full warp to reduce\");\n            #pragma unroll\n            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {\n                // Read\n                auto int4_value = __ldg(x_int4 + i);\n\n                if constexpr (kUseFP8) {\n                    // Calculate local amax\n                    auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);\n                    float fp32_values[kNumElemsPerRead];\n                    float amax = kFP8Margin, scale, scale_inv;\n                    #pragma unroll\n                    for (int j = 0; j < kNumElemsPerRead; ++j) {\n                        fp32_values[j] = static_cast<float>(bf16_values[j]);\n                        amax = fmaxf(amax, fabsf(fp32_values[j]));\n                    }\n\n                    // Reduce amax and scale\n                    EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, \"Invalid vectorization\");\n                    amax = warp_reduce_max<16>(amax);\n                    calculate_fp8_scales(amax, scale, scale_inv, round_scale);\n                    if (lane_id == 0 or lane_id == 16)\n                        rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;\n\n                    // Cast into send buffer\n                    vec_t int2_value;\n                    auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);\n                    #pragma unroll\n                    for (int j = 0; j < kNumElemsPerRead; j += 2) {\n                        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};\n                        fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);\n                    }\n                    rdma_x_vec[i] = int2_value;\n                } else {\n                    // Reinterpret-cast is for C++14 
compatibility\n                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);\n                }\n            }\n            asm volatile(\"bar.sync 1, %0;\" ::\"r\"(num_threads));\n\n            // Issue IBGDA sends\n            if (dst_expert_idx >= 0) {\n                int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;\n                slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);\n                const auto dst_rank = dst_expert_idx / num_local_experts;\n                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;\n                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);\n                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +\n                    dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +\n                    rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg;\n                const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);\n                if (not is_rank_masked<true>(mask_buffer_ptr, dst_rank)) {\n                    if (dst_p2p_ptr == 0) {\n                        nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);\n                    } else {\n                        // NOTES: only 2 load iterations for 7K hidden with 8 unrolls\n                        const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);\n                        const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);\n                        UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);\n                    }\n                }\n\n                // Increase counter after finishing\n                __syncwarp();\n                lane_id == 0 ? 
atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;\n            }\n        }\n    } else if (warp_id == num_warps - 1) {\n        EP_DEVICE_ASSERT(num_sms > 1);\n        if (sm_id == 0) {\n            // The first SM is also responsible for checking QPs\n            EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);\n\n            // The first SM is also responsible for cleaning the next buffer\n            #pragma unroll\n            for (int i = lane_id; i < num_next_clean_int; i += 32)\n                next_clean[i] = 0;\n\n            // Notify before executing `int_p`\n            __syncwarp();\n            #pragma unroll\n            for (int i = lane_id; i < num_experts; i += 32)\n                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);\n        }\n\n        // This SM should be responsible for some destination experts, read `topk_idx` for them\n        int expert_count[kNumMaxWarpGroups] = {0};\n        const auto expert_begin_idx = sm_id * num_warp_groups;\n        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);\n\n        // Per lane count\n        #pragma unroll 8\n        for (int i = lane_id; i < num_tokens * num_topk; i += 32) {\n            auto idx = static_cast<int>(__ldg(topk_idx + i));\n            if (idx >= expert_begin_idx and idx < expert_end_idx)\n                expert_count[idx - expert_begin_idx]++;\n        }\n\n        // Warp reduce\n        #pragma unroll\n        for (int i = expert_begin_idx; i < expert_end_idx; ++i) {\n            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);\n            if (lane_id == 0) {\n                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;\n                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);\n            }\n        }\n    }\n    __syncthreads();\n\n    // Issue count sends\n    if 
(responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {\n        const auto dst_rank = responsible_expert_idx / num_local_experts;\n        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;\n        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];\n\n        // Wait until all local sends are issued, then send the expert count\n        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2)\n            ;\n        auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);\n        auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);\n        if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {\n            if (dst_p2p_ptr == 0) {\n                nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);\n            } else {\n                st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);\n            }\n        }\n\n        // Clean workspace for next use\n        atomic_counter_per_expert[responsible_expert_idx] = 0;\n        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;\n\n        // Clean `packed_recv_count`\n        if (dst_rank == 0)\n            packed_recv_count[dst_expert_local_idx] = 0;\n    }\n    __syncwarp();\n\n// Receiving phase\nLOW_LATENCY_DISPATCH_RECV:\n    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)\n        return;\n\n    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible\n    if (phases & LOW_LATENCY_SEND_PHASE)\n        cg::this_grid().sync();\n\n    // Receiving and packing\n    if (responsible_expert_idx < num_experts) {\n        const auto src_rank = responsible_expert_idx / num_local_experts;\n        const auto local_expert_idx = responsible_expert_idx % num_local_experts;\n        const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +\n            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +\n            src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;\n        const auto recv_x_int4 =\n            static_cast<int4*>(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;\n        const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;\n        const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;\n        const auto num_aligned_scales = align_up<int>(num_scales, sizeof(float) / sizeof(scale_t));\n        const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +\n            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;\n\n        // Shared between sub-warps in warp groups\n        __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];\n\n        // Wait for tokens to arrive\n        // NOTES: using sub-warp 1 to overlap with sub-warp 0\n        int num_recv_tokens = 0, recv_token_begin_idx;\n        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);\n        if (sub_warp_id == 1 and lane_id == 0) {\n            auto start_time = clock64();\n            uint64_t wait_recv_cost = 0;\n            if (not is_rank_masked(mask_buffer_ptr, src_rank)) {\n                while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) ==\n                           0                                                               // data has not arrived\n                       && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES  // not timed out\n                )\n                    ;\n            }\n            // Do not receive tokens if the rank timed out or is masked\n            if (num_recv_tokens 
== 0)\n                num_recv_tokens = -1;\n            // Mask the rank on timeout\n            if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {\n                printf(\"Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d\\n\",\n                       rank,\n                       local_expert_idx,\n                       src_rank);\n                if (mask_buffer_ptr == nullptr)\n                    trap();\n                atomicExch(mask_buffer_ptr + src_rank, 1);\n            }\n\n            num_recv_tokens = -num_recv_tokens - 1;\n            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);\n            shared_num_recv_tokens[warp_group_id] = num_recv_tokens;\n            shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;\n            recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);\n\n            // Add stats for diagnosis\n            if (cumulative_local_expert_recv_stats != nullptr)\n                atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);\n            if (dispatch_wait_recv_cost_stats != nullptr)\n                atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost);\n        }\n        asm volatile(\"bar.sync %0, %1;\" ::\"r\"(warp_group_id + 2), \"r\"(num_warps_per_group * 32));\n        num_recv_tokens = shared_num_recv_tokens[warp_group_id];\n        recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];\n\n        // Copy tokens\n        EP_DEVICE_ASSERT(num_scales <= 64);\n        for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {\n            // Copy source info\n            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);\n            if (lane_id == 0)\n                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);\n            
__syncwarp();\n\n            // Copy data\n            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls\n            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));\n            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;\n            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);\n\n            // Copy scales\n            if constexpr (kUseFP8) {\n                // Equivalent CuTe layout:\n                //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))\n                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);\n                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));\n                const auto token_idx = recv_token_begin_idx + i;\n                const auto token_stride = num_elems_per_pack;\n                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;\n                if (lane_id < num_scales) {\n                    const auto pack_idx = lane_id / num_elems_per_pack;\n                    const auto elem_idx = lane_id % num_elems_per_pack;\n                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));\n                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;\n                }\n                if (lane_id + 32 < num_scales) {\n                    const auto pack_idx = (lane_id + 32) / num_elems_per_pack;\n                    const auto elem_idx = (lane_id + 32) % num_elems_per_pack;\n                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));\n                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;\n       
         }\n            }\n        }\n    }\n}\n\nvoid dispatch(void* packed_recv_x,\n              void* packed_recv_x_scales,\n              int* packed_recv_src_info,\n              int64_t* packed_recv_layout_range,\n              int* packed_recv_count,\n              int* mask_buffer_ptr,\n              int* cumulative_local_expert_recv_stats,\n              int64_t* dispatch_wait_recv_cost_stats,\n              void* rdma_recv_x,\n              int* rdma_recv_count,\n              void* rdma_x,\n              const void* x,\n              const topk_idx_t* topk_idx,\n              int* next_clean,\n              int num_next_clean_int,\n              int num_tokens,\n              int hidden,\n              int num_max_dispatch_tokens_per_rank,\n              int num_topk,\n              int num_experts,\n              int rank,\n              int num_ranks,\n              bool use_fp8,\n              bool round_scale,\n              bool use_ue8m0,\n              void* workspace,\n              int num_device_sms,\n              cudaStream_t stream,\n              int phases) {\n    constexpr int kNumMaxTopK = 11;\n    const int num_warp_groups = ceil_div(num_experts, num_device_sms);\n    const int num_warps_per_group = 32 / num_warp_groups;\n    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);\n    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);\n\n    const auto num_warps = num_warp_groups * num_warps_per_group;\n    const auto num_sms = ceil_div(num_experts, num_warp_groups);\n    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);\n\n    // Workspace checks\n    auto atomic_counter_per_expert = static_cast<int*>(workspace);\n    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;\n    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);\n\n    // FP8 checks\n    if (use_ue8m0)\n        EP_HOST_ASSERT(round_scale and \"UE8M0 SF requires `round_scale=True`\");\n\n#define 
DISPATCH_LAUNCH_CASE(hidden)                         \\\n    {                                                        \\\n        auto dispatch_func = dispatch<false, false, hidden>; \\\n        if (use_fp8 and not use_ue8m0)                       \\\n            dispatch_func = dispatch<true, false, hidden>;   \\\n        if (use_fp8 and use_ue8m0)                           \\\n            dispatch_func = dispatch<true, true, hidden>;    \\\n        LAUNCH_KERNEL(&cfg,                                  \\\n                      dispatch_func,                         \\\n                      packed_recv_x,                         \\\n                      packed_recv_x_scales,                  \\\n                      packed_recv_src_info,                  \\\n                      packed_recv_layout_range,              \\\n                      packed_recv_count,                     \\\n                      mask_buffer_ptr,                       \\\n                      cumulative_local_expert_recv_stats,    \\\n                      dispatch_wait_recv_cost_stats,         \\\n                      rdma_recv_x,                           \\\n                      rdma_recv_count,                       \\\n                      rdma_x,                                \\\n                      x,                                     \\\n                      topk_idx,                              \\\n                      atomic_counter_per_expert,             \\\n                      atomic_finish_counter_per_expert,      \\\n                      next_clean,                            \\\n                      num_next_clean_int,                    \\\n                      num_tokens,                            \\\n                      num_max_dispatch_tokens_per_rank,      \\\n                      num_topk,                              \\\n                      num_experts,                           \\\n                      rank,                               
   \\\n                      num_ranks,                             \\\n                      num_warp_groups,                       \\\n                      num_warps_per_group,                   \\\n                      round_scale,                           \\\n                      phases);                               \\\n    }                                                        \\\n    break\n\n    SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);\n    SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);\n#undef DISPATCH_LAUNCH_CASE\n}\n\ntemplate <int kNumSendUnrolls>\n__forceinline__ __device__ int logfmt_encode(void* buffer, nv_bfloat162* shared_amaxmin, const int& lane_id) {\n    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);\n    constexpr float kLogThreshold = 0;\n    constexpr float kMinClip = 32;  // `== log_2(2 ^ (2 ^ 5))`\n    constexpr int kNumBits = 10;\n    constexpr int kNumValues = 1 << (kNumBits - 1);\n\n    int4 int4_values[kNumSendUnrolls];\n    const auto& uint32_values = reinterpret_cast<uint32_t*>(int4_values);\n    const auto& bf162_values = reinterpret_cast<nv_bfloat162*>(int4_values);\n\n    // Calculate lane offset\n    const auto& ld_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4)));\n    const auto& st_buffer =\n        reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4) * 10 / 16));\n\n    // Local log amax\n    auto bf162_amax = __nv_bfloat162(CUDART_ZERO_BF16, CUDART_ZERO_BF16);\n    auto bf162_amin = __nv_bfloat162(CUDART_INF_BF16, CUDART_INF_BF16);\n    uint32_t local_signs = 0;\n    #pragma unroll\n    for (int k = 0; k < kNumSendUnrolls * kNumElemsPerInt4 / 2; ++k) {\n        // TODO: eliminate bank conflicts\n        uint32_values[k] = ld_buffer[k];\n        local_signs |= ((uint32_values[k] >> 15) & 1) << (k * 2);\n        local_signs |= ((uint32_values[k] >> 31) & 1) << (k * 2 + 1);\n        
uint32_values[k] &= 0x7fff7fff;\n\n        bf162_amax = __hmax2(bf162_amax, bf162_values[k]);\n        bf162_amin = __hmin2(bf162_amin, bf162_values[k]);\n    }\n\n    // Reduce per 128 channels\n    // TODO: figure out how the hardware does 2-byte min/max\n    auto amax = std::max(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));\n    auto amin = std::min(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));\n    constexpr static int kNumLanesToReduce = 128 * sizeof(nv_bfloat16) / (kNumSendUnrolls * sizeof(int4));\n    amax = warp_reduce_max<kNumLanesToReduce>(amax);\n    amin = warp_reduce_min<kNumLanesToReduce>(amin);\n\n    // Write min/max into shared memory\n    if (shared_amaxmin != nullptr)\n        *shared_amaxmin = __nv_bfloat162(amax, amin);\n    __syncwarp();\n\n    // Calculate log amin/amax float\n    const auto& log_amax = log2f_approx(amax);\n    const auto& log_amin = fmaxf(log2f_approx(amin), log_amax - kMinClip);\n    const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);\n\n    // Cast into LogFMT-10 if the range conditions are satisfied\n    if (enable_cast) {\n        const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);\n        const auto step_inv = 1.0f / step;\n        const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;\n        const auto fused_rounding = rounding - log_amin * step_inv;\n\n        // Pack every 256 bits into 160 bits\n        EP_STATIC_ASSERT(kNumSendUnrolls == 2 or kNumSendUnrolls == 4, \"kNumSendUnrolls == 2 or 4 only\");\n        uint32_t encoded[kNumElemsPerInt4 * 2];\n        #pragma unroll 1\n        for (int i = 0; i < kNumSendUnrolls / 2; ++i) {\n            #pragma unroll\n            for (int k = 0; k < kNumElemsPerInt4; ++k) {\n                const auto& [x, y] = __bfloat1622float2(bf162_values[i * kNumElemsPerInt4 + k]);\n                encoded[k * 2 + 0] = 
__float2uint_rd(fmaxf(log2f_approx(x) * step_inv + fused_rounding, 0));\n                encoded[k * 2 + 1] = __float2uint_rd(fmaxf(log2f_approx(y) * step_inv + fused_rounding, 0));\n            }\n            st_buffer[i * 5 + 0] = (encoded[0] >> 0) | (encoded[1] << 9) | (encoded[2] << 18) | (encoded[3] << 27);\n            st_buffer[i * 5 + 1] = (encoded[3] >> 5) | (encoded[4] << 4) | (encoded[5] << 13) | (encoded[6] << 22) | (encoded[7] << 31);\n            st_buffer[i * 5 + 2] = (encoded[7] >> 1) | (encoded[8] << 8) | (encoded[9] << 17) | (encoded[10] << 26);\n            st_buffer[i * 5 + 3] =\n                (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);\n            st_buffer[i * 5 + 4] = (encoded[14] >> 2) | (encoded[15] << 7) | ((i == 0) ? (local_signs << 16) : (local_signs & 0xffff0000u));\n        }\n        tma_store_fence();\n        __syncwarp();\n    }\n\n    // Return TMA copy bytes\n    return enable_cast ? (32 * (kNumSendUnrolls * sizeof(int4) * 8 * 10 / 16 / 8)) : (32 * (kNumSendUnrolls * sizeof(int4)));\n}\n\ntemplate <int kNumLanes, int kNumSendUnrolls, int kNumRecvUnrolls>\n__forceinline__ __device__ void logfmt_check_amaxmin(\n    uint8_t* meta_buffer, float2* shared_log_amax, float2* shared_log_amin, int* shared_cast_info, const int lane_id) {\n    constexpr float kLogThreshold = 0;\n    constexpr float kMinClip = 32;  // `== log_2(2 ^ (2 ^ 5))`\n\n    bool enable_cast = true;\n    if (lane_id < kNumLanes) {\n        // Calculate log amin/amax float\n        auto amaxmin2 = reinterpret_cast<uint64_t*>(meta_buffer)[lane_id];\n        const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2);\n        float log_amax[2], log_amin[2];\n        #pragma unroll\n        for (int i = 0; i < 2; ++i) {\n            auto amax = static_cast<float>(bf162_amaxmin[i].x);\n            auto amin = static_cast<float>(bf162_amaxmin[i].y);\n            log_amax[i] = 
log2f_approx(amax);\n            log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : fmaxf(log2f_approx(amin), log_amax[i] - kMinClip);\n            enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];\n        }\n        shared_log_amax[lane_id] = make_float2(log_amax[0], log_amax[1]);\n        shared_log_amin[lane_id] = make_float2(log_amin[0], log_amin[1]);\n    }\n\n    const auto& casted = warp_reduce_and<kNumSendUnrolls>(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls) : 0u;\n    const auto& num_casted_prefix = __popc(warp_reduce_or<kNumRecvUnrolls, true>(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1));\n\n    if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0)\n        shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u);\n    __syncwarp();\n}\n\ntemplate <int kNumRecvUnrolls>\n__forceinline__ __device__ void decode_and_accumulate(\n    uint32_t* ld_buffer, float* accum, const float& log_amax, const float& log_amin, const bool& enable_cast, const float& weight) {\n    if (enable_cast) {\n        constexpr int kNumBits = 10;\n        constexpr int kNumValues = 1 << (kNumBits - 1);\n\n        const auto& step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);\n        auto decode = [=](const uint32_t& encoded, const uint32_t& sign) {\n            const auto decoded = encoded == 0 ? .0f : exp2f_approx((encoded - 1) * step + log_amin);\n            return sign ? 
-decoded : decoded;\n        };\n\n        EP_STATIC_ASSERT(kNumRecvUnrolls == 2 or kNumRecvUnrolls == 4, \"kNumRecvUnrolls == 2 or 4 only\");\n        #pragma unroll\n        for (int i = 0; i < kNumRecvUnrolls / 2; ++i) {\n            uint32_t concat[6];\n            concat[0] = ld_buffer[i * 5];\n            #pragma unroll\n            for (int k = 1; k < 5; ++k)\n                concat[k] = (ld_buffer[i * 5 + k - 1] >> (32 - k * 5)) | (ld_buffer[i * 5 + k] << (k * 5));\n            concat[5] = ld_buffer[i * 5 + 4] >> 7;\n\n            const uint32_t& local_signs = ld_buffer[i * 5 + 4] >> 16;\n            #pragma unroll\n            for (int k = 0; k < 5; ++k) {\n                accum[i * 16 + k * 3 + 0] += decode((concat[k] >> 0) & 0x1ff, (local_signs >> (k * 3 + 0)) & 1) * weight;\n                accum[i * 16 + k * 3 + 1] += decode((concat[k] >> 9) & 0x1ff, (local_signs >> (k * 3 + 1)) & 1) * weight;\n                accum[i * 16 + k * 3 + 2] += decode((concat[k] >> 18) & 0x1ff, (local_signs >> (k * 3 + 2)) & 1) * weight;\n            }\n            accum[i * 16 + 15] += decode(concat[5] & 0x1ff, (local_signs >> 15) & 1) * weight;\n        }\n    } else {\n        #pragma unroll\n        for (int k = 0; k < kNumRecvUnrolls * 4; ++k) {\n            auto bf16_pack = *reinterpret_cast<__nv_bfloat162*>(ld_buffer + k);\n            accum[k * 2 + 0] += static_cast<float>(bf16_pack.x) * weight;\n            accum[k * 2 + 1] += static_cast<float>(bf16_pack.y) * weight;\n        }\n    }\n}\n\ntemplate <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls>\n__global__ __launch_bounds__(1024, 1) void combine(void* combined_x,\n                                                   void* rdma_recv_x,\n                                                   int* rdma_recv_flag,\n                                                   void* rdma_send_x,\n                                                   const void* x,\n                                                   
const topk_idx_t* topk_idx,\n                                                   const float* topk_weights,\n                                                   const int* src_info,\n                                                   const int64_t* layout_range,\n                                                   int* mask_buffer_ptr,\n                                                   int64_t* combine_wait_recv_cost_stats,\n                                                   int* next_clean,\n                                                   int num_next_clean_int,\n                                                   int* atomic_clean_flag,\n                                                   int num_combined_tokens,\n                                                   int hidden,\n                                                   int num_topk,\n                                                   int num_max_dispatch_tokens_per_rank,\n                                                   int num_experts,\n                                                   int rank,\n                                                   int num_ranks,\n                                                   int num_warp_groups,\n                                                   int num_warps_per_group,\n                                                   int phases,\n                                                   bool zero_copy) {\n    const auto sm_id = __shfl_sync(0xffffffff, static_cast<int>(blockIdx.x), 0);\n    const auto num_sms = __shfl_sync(0xffffffff, static_cast<int>(gridDim.x), 0);\n    const auto thread_id = static_cast<int>(threadIdx.x);\n    const auto num_threads = __shfl_sync(0xffffffff, static_cast<int>(blockDim.x), 0);\n    const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id();\n    const auto num_local_experts = num_experts / num_ranks;\n    const auto warp_group_id = warp_id / num_warps_per_group;\n    const auto sub_warp_id = warp_id % 
num_warps_per_group;\n    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;\n\n    extern __shared__ __align__(1024) uint8_t smem_buffer[];\n\n    // Data type stuff\n    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);\n    constexpr int64_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;\n\n    // Use different unroll factors for send and recv phases\n    constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2;\n    constexpr int kNumRecvUnrolls = 2;\n    constexpr int hidden_bf16_int4_pad = align_up(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);\n    EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, \"Invalid hidden\");\n    EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, \"Invalid unrolls\");\n    EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, \"Invalid hidden\");\n    EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, \"Invalid unroll factors\");\n\n    // Message package\n    EP_STATIC_ASSERT(kHidden % 128 == 0, \"Invalid hidden\");\n    constexpr int kNumDivisions = kHidden / 128;\n    constexpr int kNumMetaBytes = kNumDivisions * sizeof(nv_bfloat162);\n    constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes;\n    EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, \"Invalid vectorization\");\n\n    // Sending phase\n    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)\n        goto LOW_LATENCY_COMBINE_RECV;\n\n    // Clean up next buffer\n    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {\n        #pragma unroll\n        for (int i = lane_id; i < num_next_clean_int; i += 32)\n            next_clean[i] = 0;\n\n        // Notify before executing `int_p`\n        __syncwarp();\n        if (lane_id == 0)\n            atomic_add_release_global(atomic_clean_flag, num_experts);\n    }\n\n    // Issue IBGDA sends\n    if 
(responsible_expert_idx < num_experts) {\n        const auto dst_rank = responsible_expert_idx / num_local_experts;\n        const auto local_expert_idx = responsible_expert_idx % num_local_experts;\n        const auto global_expert_idx = rank * num_local_experts + local_expert_idx;\n        const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);\n        const auto local_x =\n            static_cast<const int4*>(x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;\n        const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;\n        const auto rdma_send_x_vec =\n            static_cast<uint8_t*>(rdma_send_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;\n\n        // Unpack layout\n        int offset, num_tokens_to_send;\n        unpack2(layout, num_tokens_to_send, offset);\n\n        // TMA stuffs\n        constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls;\n        constexpr int kNumStages = 3;\n        constexpr int kNumPrefetch = 1;\n        EP_STATIC_ASSERT(kNumStages == 3 and kNumPrefetch == 1, \"Invalid stages\");\n\n        auto smem_ptr = smem_buffer + warp_id * (kNumStages * (kNumTMABufferBytes + 16) + kNumMetaBytes);\n        uint32_t tma_phase = 0;\n        auto tma_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * (kNumTMABufferBytes + 16)); });\n        auto full_barriers = PatternVisitor(\n            [=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); });\n        auto meta_buffers = kUseLogFMT ? 
reinterpret_cast<nv_bfloat162*>(smem_ptr + kNumStages * (kNumTMABufferBytes + 16)) : nullptr;\n        EP_STATIC_ASSERT(kNumSendUnrolls * kNumStages <= 12, \"TMA buffer size exceeds limit\");\n\n        // Initialize m-barriers\n        if (lane_id < kNumStages) {\n            mbarrier_init(full_barriers[lane_id], 1);\n            fence_barrier_init();\n        }\n        __syncwarp();\n\n        constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumSendUnrolls);\n        auto tma_load_and_arrive = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) {\n            tma_load_1d(tma_buffers[stage_idx], gmem_ptr, full_barriers[stage_idx], num_bytes);\n            mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_bytes);\n        };\n        auto get_num_tma_bytes = [&](const int& offset_int4) {\n            return min(kNumTMABufferBytes, static_cast<int>((hidden_bf16_int4 - offset_int4) * sizeof(int4)));\n        };\n\n        // Issue IBGDA send\n        if (not is_rank_masked<true>(mask_buffer_ptr, dst_rank)) {\n            for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {\n                const auto x_int4 = local_x + token_idx * hidden_bf16_int4;\n                const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);\n                const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);\n\n                // Copy directly to local rank, or copy to buffer and issue RDMA\n                const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0);\n                const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);\n                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +\n                    (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;\n                const auto dst_p2p_ptr = 
nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);\n                int num_send_bytes = hidden * sizeof(nv_bfloat16);\n\n                if (not zero_copy or dst_p2p_ptr != 0) {\n                    // Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`\n                    const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;\n                    const auto cpy_dst_int4_ptr =\n                        dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);\n\n                    // Prefetch\n                    if (elect_one_sync())\n                        tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0));\n                    __syncwarp();\n\n                    int tma_offset_bytes = kNumMetaBytes;\n                    #pragma unroll\n                    for (int i = lane_id * kNumSendUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumSendUnrolls, ++iter_idx) {\n                        // Load the next iteration\n                        const int& stage_idx = iter_idx % kNumStages;\n                        const int& next_stage_idx = (iter_idx + 1) % kNumStages;\n                        if (iter_idx + 1 < kNumIters and elect_one_sync()) {\n                            tma_store_wait<kNumStages - kNumPrefetch - 1>();\n                            const auto& offset_int4 = i + 32 * kNumSendUnrolls;\n                            tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));\n                        }\n                        __syncwarp();\n\n                        // Wait for the current TMA arrival\n                        EP_STATIC_ASSERT(kNumStages < 32, \"Too many stages\");\n                        mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);\n                        if constexpr (kUseLogFMT) {\n                            // Cast if possible\n                            constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4;\n                            int num_tma_bytes = logfmt_encode<kNumSendUnrolls>(\n                                tma_buffers[stage_idx],\n                                // NOTES: only the leader lane will write the result\n                                (i % kNumInt4PerDivision == 0) ? meta_buffers + i / kNumInt4PerDivision : nullptr,\n                                lane_id);\n                            if (elect_one_sync())\n                                tma_store_1d(\n                                    tma_buffers[stage_idx], reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + tma_offset_bytes, num_tma_bytes);\n                            tma_offset_bytes += num_tma_bytes;\n                        } else {\n                            // BF16 original values\n                            if (elect_one_sync())\n                                tma_store_1d(tma_buffers[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));\n                        }\n                        __syncwarp();\n                    }\n\n                    // Store metadata (min/max values) for LogFMT\n                    if constexpr (kUseLogFMT) {\n                        num_send_bytes = tma_offset_bytes;\n                        if (elect_one_sync())\n                            tma_store_1d(meta_buffers, cpy_dst_int4_ptr, kNumMetaBytes);\n                    }\n\n                    // Flush all stores\n                    tma_store_wait<0>();\n                    __syncwarp();\n                }\n\n                // Issue RDMA\n                // NOTES: for zero-copy mode, we assume the data is already in the send buffer\n                if (dst_p2p_ptr == 0)\n                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset);\n            }\n        }\n\n        // Put the finishing flag\n        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 
16);\n        asm volatile(\"bar.sync %0, %1;\" ::\"r\"(warp_group_id + 1), \"r\"(num_warps_per_group * 32));\n        if (sub_warp_id == 1 and lane_id == 0) {\n            while (ld_acquire_global(atomic_clean_flag) == 0)\n                ;\n            auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);\n            auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);\n            if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {\n                if (dst_p2p_ptr == 0) {\n                    nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);\n                } else {\n                    st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);\n                }\n            }\n            atomic_add_release_global(atomic_clean_flag, -1);\n        }\n        __syncwarp();\n\n        // Destroy m-barriers\n        if (lane_id < kNumStages) {\n            mbarrier_inval(full_barriers[lane_id]);\n            fence_barrier_init();\n        }\n        __syncwarp();\n    }\n\n// Receiving phase\nLOW_LATENCY_COMBINE_RECV:\n    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)\n        return;\n\n    // Wait for all ranks to arrive\n    if (responsible_expert_idx < num_experts) {\n        EP_DEVICE_ASSERT(num_warps_per_group > 1);\n        if (sub_warp_id == 0 and lane_id == 0) {\n            const auto src_rank = responsible_expert_idx / num_local_experts;\n            auto start_time = clock64();\n            uint64_t wait_recv_cost = 0;\n            if (not is_rank_masked(mask_buffer_ptr, src_rank)) {\n                while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0  // recv not ready\n                       && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES   // not timed out\n                )\n                    ;\n            }\n            // Mask the rank on timeout\n            if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {\n                
printf(\"Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d\\n\",\n                       rank,\n                       responsible_expert_idx % num_local_experts,\n                       src_rank);\n                if (mask_buffer_ptr == nullptr)\n                    trap();\n                atomicExch(mask_buffer_ptr + src_rank, 1);\n            }\n\n            if (combine_wait_recv_cost_stats != nullptr) {\n                atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);\n            }\n        }\n    }\n    cg::this_grid().sync();\n\n    // Reassign warp groups\n    constexpr int kMaxNumGroups = 2;\n    const int num_decode_warps = hidden_bf16_int4_pad / (kNumRecvUnrolls * 32);\n    const int num_groups = min(kMaxNumGroups, (num_threads / 32) / (num_decode_warps + 1));\n    const int decode_warp_idx = __shfl_sync(0xffffffff, warp_id % (num_decode_warps + 1), 0);\n    const int group_idx = __shfl_sync(0xffffffff, warp_id / (num_decode_warps + 1), 0);\n    EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, \"Invalid vectorization\");\n    EP_DEVICE_ASSERT(num_topk <= 32);\n    EP_DEVICE_ASSERT(num_groups > 0);\n\n    if (group_idx < num_groups) {\n        constexpr int kNumStages = 3;\n        constexpr int kNumTMABufferBytes = 16 * 2 + kHidden * 2;\n        constexpr int kNumBF16PerWarpBytes = 32 * kNumRecvUnrolls * kNumElemsPerInt4 * 2;\n        constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes / 16 * 10;\n        constexpr int kNumDivisionBytes = kNumDivisions * sizeof(uint32_t);\n        constexpr int kNumBytesPerGroup = kNumStages * kNumTMABufferBytes + kHidden * 2 + kNumStages * kNumDivisionBytes * 3;\n\n        // Reallocate shared memory\n        const auto smem_group_buffer = smem_buffer + kNumBytesPerGroup * group_idx;\n        auto full_barriers =\n            PatternVisitor([=](const int& i) { return 
reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes); });\n        auto empty_barriers =\n            PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes + 8); });\n        auto tma_ld_buffers =\n            PatternVisitor([=](const int& i) { return reinterpret_cast<uint8_t*>(smem_group_buffer + i * kNumTMABufferBytes + 16); });\n        auto tma_st_buffers = PatternVisitor([=](const int& i) {\n            return reinterpret_cast<uint32_t*>(smem_group_buffer + kNumStages * kNumTMABufferBytes + i * kNumBF16PerWarpBytes);\n        });\n\n        // Redundant when logfmt is disabled\n        const auto smem_group_ptr = smem_group_buffer + kNumStages * kNumTMABufferBytes + kHidden * 2;\n        auto log_amax_buffers =\n            PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + i * kNumDivisionBytes); });\n        auto log_amin_buffers = PatternVisitor([=](const int& i) {\n            return reinterpret_cast<float*>(smem_group_ptr + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes);\n        });\n        auto cast_info_buffers = PatternVisitor([=](const int& i) {\n            return reinterpret_cast<int*>(smem_group_ptr + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes);\n        });\n\n        uint32_t tma_phase = 0;\n        EP_STATIC_ASSERT(kNumStages < 32, \"Too many stages\");\n        if (decode_warp_idx == num_decode_warps)\n            tma_phase = (1 << kNumStages) - 1;\n\n        // Initialize m-barriers\n        if (decode_warp_idx == num_decode_warps and lane_id < kNumStages) {\n            mbarrier_init(full_barriers[lane_id], 1);\n            mbarrier_init(empty_barriers[lane_id], num_decode_warps);\n        }\n        asm volatile(\"bar.sync %0, %1;\" ::\"r\"(group_idx + 1), \"r\"((num_decode_warps + 1) * 32));\n\n        int stage_idx = 0, topk_idx_by_lane = 0;\n        EP_STATIC_ASSERT(kNumMaxTopk <= 32, \"Invalid number 
of topks\");\n        if (decode_warp_idx == num_decode_warps) {\n            // TMA load warp\n            for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {\n                if (lane_id < num_topk)\n                    topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));\n                for (int i = 0; i < num_topk; ++i) {\n                    int topk_idx_reg = __shfl_sync(0xffffffff, topk_idx_by_lane, i);\n                    if (topk_idx_reg < 0)\n                        continue;\n                    if (is_rank_masked(mask_buffer_ptr, topk_idx_reg / num_local_experts))\n                        continue;\n\n                    mbarrier_wait<true>(empty_barriers[stage_idx], tma_phase, stage_idx);\n                    auto buffer = static_cast<uint8_t*>(rdma_recv_x) +\n                        (topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot;\n                    if constexpr (kUseLogFMT) {\n                        logfmt_check_amaxmin<kNumDivisions / 2, kNumSendUnrolls, kNumRecvUnrolls>(\n                            buffer,\n                            reinterpret_cast<float2*>(log_amax_buffers[stage_idx]),\n                            reinterpret_cast<float2*>(log_amin_buffers[stage_idx]),\n                            cast_info_buffers[stage_idx],\n                            lane_id);\n                    }\n                    if (elect_one_sync()) {\n                        int num_casted = 0;\n                        if constexpr (kUseLogFMT) {\n                            const auto& info = cast_info_buffers[stage_idx][num_decode_warps - 1];\n                            num_casted = (info >> 1) + (info & 1);\n                        }\n                        int num_tma_bytes = num_casted * kNumLogFMTPerWarpBytes + (num_decode_warps - num_casted) * kNumBF16PerWarpBytes;\n                        tma_load_1d(\n           
                 tma_ld_buffers[stage_idx], buffer + (kUseLogFMT ? kNumMetaBytes : 0), full_barriers[stage_idx], num_tma_bytes);\n                        mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_tma_bytes);\n                    }\n                    __syncwarp();\n                    stage_idx = (stage_idx + 1) % kNumStages;\n                }\n            }\n        } else {\n            // Reduction warps\n            float topk_weights_by_lane;\n            for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {\n                if (lane_id < num_topk) {\n                    topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));\n                    topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);\n                }\n                __syncwarp();\n\n                float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};\n                for (int i = 0; i < num_topk; ++i) {\n                    int topk_idx_reg = __shfl_sync(0xffffffff, topk_idx_by_lane, i);\n                    if (topk_idx_reg < 0)\n                        continue;\n                    if (is_rank_masked(mask_buffer_ptr, topk_idx_reg / num_local_experts))\n                        continue;\n                    const auto& topk_weight = __shfl_sync(0xffffffff, topk_weights_by_lane, i);\n\n                    mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);\n                    if constexpr (kUseLogFMT) {\n                        const auto& info = cast_info_buffers[stage_idx][decode_warp_idx];\n                        bool enable_cast = info & 1;\n                        int num_casted_prefix = info >> 1;\n                        int tma_offset =\n                            kNumLogFMTPerWarpBytes * num_casted_prefix + kNumBF16PerWarpBytes * (decode_warp_idx - num_casted_prefix);\n                        int division_idx 
= decode_warp_idx * (kNumRecvUnrolls * 2) + lane_id * kNumRecvUnrolls / 16;\n                        decode_and_accumulate<kNumRecvUnrolls>(\n                            reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset +\n                                                        (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / 32 * lane_id),\n                            combined_values,\n                            log_amax_buffers[stage_idx][division_idx],\n                            log_amin_buffers[stage_idx][division_idx],\n                            enable_cast,\n                            topk_weight);\n                    } else {\n                        int tma_offset = kNumBF16PerWarpBytes * decode_warp_idx;\n                        decode_and_accumulate<kNumRecvUnrolls>(\n                            reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + kNumBF16PerWarpBytes / 32 * lane_id),\n                            combined_values,\n                            0,\n                            0,\n                            false,\n                            topk_weight);\n                    }\n\n                    if (elect_one_sync())\n                        mbarrier_arrive(empty_barriers[stage_idx]);\n                    stage_idx = (stage_idx + 1) % kNumStages;\n                }\n                tma_store_wait<0>();\n\n                #pragma unroll\n                for (int k = 0; k < kNumRecvUnrolls * 4; ++k) {\n                    auto combined_pack = __nv_bfloat162(combined_values[k * 2], combined_values[k * 2 + 1]);\n                    tma_st_buffers[decode_warp_idx][kNumRecvUnrolls * 4 * lane_id + k] = *reinterpret_cast<uint32_t*>(&combined_pack);\n                }\n                tma_store_fence();\n                if (elect_one_sync()) {\n                    tma_store_1d(tma_st_buffers[decode_warp_idx],\n                                 static_cast<int4*>(combined_x) + token_idx * 
hidden_bf16_int4 + decode_warp_idx * kNumRecvUnrolls * 32,\n                                 kNumBF16PerWarpBytes);\n                }\n                __syncwarp();\n            }\n        }\n    }\n}\n\nvoid combine(void* combined_x,\n             void* rdma_recv_x,\n             int* rdma_recv_flag,\n             void* rdma_send_x,\n             const void* x,\n             const topk_idx_t* topk_idx,\n             const float* topk_weights,\n             const int* src_info,\n             const int64_t* layout_range,\n             int* mask_buffer_ptr,\n             int64_t* combine_wait_recv_cost_stats,\n             int* next_clean,\n             int num_next_clean_int,\n             int num_combined_tokens,\n             int hidden,\n             int num_max_dispatch_tokens_per_rank,\n             int num_topk,\n             int num_experts,\n             int rank,\n             int num_ranks,\n             bool use_logfmt,\n             void* workspace,\n             int num_device_sms,\n             cudaStream_t stream,\n             int phases,\n             bool zero_copy) {\n    constexpr int kNumMaxTopk = 11;\n    const int num_warp_groups = ceil_div(num_experts, num_device_sms);\n    const int num_warps_per_group = 32 / num_warp_groups;\n    const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);\n    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);\n\n    const auto num_warps = num_warp_groups * num_warps_per_group;\n    const auto num_sms =\n        max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 
1 : ceil_div(num_combined_tokens, num_recv_per_sm));\n\n    // Check workspace\n    auto atomic_clean_flag = static_cast<int*>(workspace);\n    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);\n    EP_HOST_ASSERT(num_topk <= kNumMaxTopk);\n\n    // Online cast cannot use zero-copy\n    EP_HOST_ASSERT(not(zero_copy and use_logfmt));\n\n    constexpr int kNumStages = 3;\n    constexpr int kNumMaxUnrolls = 4;\n    constexpr int kMaxNumGroups = 2;\n\n    // Send buffer size\n    const int num_meta_bytes = hidden / 128 * 4;\n    const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;\n    const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);\n\n    // Receive buffer size\n    const int num_recv_tma_bytes = 16 + hidden * 2;\n    const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);\n\n    // Total requirement\n    const int smem_size = max(smem_send_size, smem_recv_size);\n\n#define COMBINE_LAUNCH_CASE(hidden)                                                                                                \\\n    {                                                                                                                              \\\n        auto combine_func =                                                                                                        \\\n            use_logfmt ? 
combine<true, hidden, kNumMaxTopk, kNumMaxUnrolls> : combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \\\n        SET_SHARED_MEMORY_FOR_TMA(combine_func);                                                                                   \\\n        LAUNCH_KERNEL(&cfg,                                                                                                        \\\n                      combine_func,                                                                                                \\\n                      combined_x,                                                                                                  \\\n                      rdma_recv_x,                                                                                                 \\\n                      rdma_recv_flag,                                                                                              \\\n                      rdma_send_x,                                                                                                 \\\n                      x,                                                                                                           \\\n                      topk_idx,                                                                                                    \\\n                      topk_weights,                                                                                                \\\n                      src_info,                                                                                                    \\\n                      layout_range,                                                                                                \\\n                      mask_buffer_ptr,                                                                                             \\\n                      combine_wait_recv_cost_stats,                                                                                
\\\n                      next_clean,                                                                                                  \\\n                      num_next_clean_int,                                                                                          \\\n                      atomic_clean_flag,                                                                                           \\\n                      num_combined_tokens,                                                                                         \\\n                      hidden,                                                                                                      \\\n                      num_topk,                                                                                                    \\\n                      num_max_dispatch_tokens_per_rank,                                                                            \\\n                      num_experts,                                                                                                 \\\n                      rank,                                                                                                        \\\n                      num_ranks,                                                                                                   \\\n                      num_warp_groups,                                                                                             \\\n                      num_warps_per_group,                                                                                         \\\n                      phases,                                                                                                      \\\n                      zero_copy);                                                                                                  \\\n    }                                                                                                     
                         \\\n    break\n\n    SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);\n    SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);\n#undef COMBINE_LAUNCH_CASE\n}\n\ntemplate <int kNumThreads>\n__launch_bounds__(kNumThreads, 1) __global__ void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor) {\n    const auto num_sms = static_cast<int>(gridDim.x);\n    const auto sm_id = static_cast<int>(blockIdx.x);\n    const auto num_threads = num_sms * kNumThreads;\n    const auto thread_id = sm_id * kNumThreads + static_cast<int>(threadIdx.x);\n    for (int rank_id = thread_id; rank_id < num_ranks; rank_id += num_threads) {\n        mask_tensor[rank_id] = mask_buffer_ptr[rank_id];\n    }\n}\n\nvoid query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor, cudaStream_t stream) {\n    constexpr int num_sms = 1;\n    constexpr int kNumThreads = 1024;\n    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);\n    LAUNCH_KERNEL(&cfg, query_mask_buffer<kNumThreads>, mask_buffer_ptr, num_ranks, mask_tensor);\n}\n\ntemplate <int kNumThreads>\n__launch_bounds__(kNumThreads, 1) __global__ void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask) {\n    const auto sm_id = static_cast<int>(blockIdx.x);\n    const auto thread_id = static_cast<int>(threadIdx.x);\n    if (sm_id == 0 && thread_id == 0) {\n        atomicExch(mask_buffer_ptr + rank_to_mask, mask ? 
1 : 0);\n    }\n}\n\nvoid update_mask_buffer(int* mask_buffer_ptr, int rank, bool mask, cudaStream_t stream) {\n    constexpr int num_sms = 1;\n    constexpr int kNumThreads = 32;\n    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);\n    LAUNCH_KERNEL(&cfg, update_mask_buffer<kNumThreads>, mask_buffer_ptr, rank, mask);\n}\n\ntemplate <int kNumThreads>\n__launch_bounds__(kNumThreads, 1) __global__ void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks) {\n    auto thread_id = static_cast<int>(threadIdx.x);\n    #pragma unroll\n    for (int i = thread_id; i < num_ranks; i += kNumThreads)\n        mask_buffer_ptr[i] = 0;\n}\n\nvoid clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream) {\n    constexpr int num_sms = 1;\n    constexpr int kNumThreads = 32;\n    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);\n    LAUNCH_KERNEL(&cfg, clean_mask_buffer<kNumThreads>, mask_buffer_ptr, num_ranks);\n}\n\n}  // namespace internode_ll\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/intranode.cu",
    "content": "#include \"buffer.cuh\"\n#include \"configs.cuh\"\n#include \"exception.cuh\"\n#include \"launch.cuh\"\n#include \"utils.cuh\"\n\nnamespace deep_ep {\n\nnamespace intranode {\n\ntemplate <int kNumRanks>\n__global__ void notify_dispatch(const int* num_tokens_per_rank,\n                                int* moe_recv_counter_mapped,\n                                const int* num_tokens_per_expert,\n                                int* moe_recv_expert_counter_mapped,\n                                int num_experts,\n                                int num_tokens,\n                                int num_channels,\n                                const bool* is_token_in_rank,\n                                int* channel_prefix_matrix,\n                                int* rank_prefix_matrix_copy,\n                                int num_memset_int,\n                                int expert_alignment,\n                                void** buffer_ptrs,\n                                int** barrier_signal_ptrs,\n                                int rank) {\n    auto sm_id = static_cast<int>(blockIdx.x);\n    auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);\n    auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;\n\n    if (sm_id == 0) {\n        // Barrier first\n        barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);\n\n        int *per_rank_buffer, *per_expert_buffer;\n        if (thread_id < kNumRanks) {\n            per_rank_buffer = static_cast<int*>(buffer_ptrs[thread_id]);\n            per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;\n        }\n\n        // After this loop:\n        //  - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j\n        //  - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j\n        int num_experts_per_rank = num_experts / kNumRanks;\n        if 
(thread_id < kNumRanks) {\n            per_rank_buffer[rank * kNumRanks + thread_id] = num_tokens_per_rank[thread_id];\n            #pragma unroll\n            for (int i = 0; i < num_experts_per_rank; ++i)\n                per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];\n        }\n\n        // Wait for all ranks to be finished\n        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);\n\n        // Sum per-rank counts and return to CPU\n        // Also pre-compute the prefix sum for data sending\n        auto local_per_rank_buffer = static_cast<int*>(buffer_ptrs[rank]);\n        if (thread_id < kNumRanks) {\n            #pragma unroll\n            for (int i = 1; i < kNumRanks; ++i)\n                local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id];\n            if (thread_id == rank)\n                *moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank];\n        }\n\n        // Sum per-expert counts and return to CPU\n        auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks;\n        if (thread_id < num_experts_per_rank) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < kNumRanks; ++i)\n                sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id];\n            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;\n            moe_recv_expert_counter_mapped[thread_id] = sum;\n        }\n        __syncthreads();\n\n        // Copy rank size prefix matrix to another tensor\n        #pragma unroll\n        for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)\n            rank_prefix_matrix_copy[i] = local_per_rank_buffer[i];\n\n        // Extra memset for later communication queue\n        #pragma unroll\n        for (int i = thread_id; i < num_memset_int; i += num_threads)\n           
 local_per_expert_buffer[i] = 0;\n\n        // Barrier\n        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);\n    } else {\n        int dst_rank = sm_id - 1;\n        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {\n            int token_start_idx, token_end_idx;\n            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);\n\n            // Iterate over tokens\n            int count = 0;\n            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32)\n                count += is_token_in_rank[i * kNumRanks + dst_rank];\n            count = warp_reduce_sum(count);\n            if (elect_one_sync())\n                channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;\n        }\n        __syncthreads();\n\n        // Pre-compute prefix sum for all channels\n        if (thread_id == 0) {\n            #pragma unroll\n            for (int i = 1; i < num_channels; ++i)\n                channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1];\n        }\n    }\n}\n\nvoid notify_dispatch(const int* num_tokens_per_rank,\n                     int* moe_recv_counter_mapped,\n                     int num_ranks,\n                     const int* num_tokens_per_expert,\n                     int* moe_recv_expert_counter_mapped,\n                     int num_experts,\n                     int num_tokens,\n                     const bool* is_token_in_rank,\n                     int* channel_prefix_matrix,\n                     int* rank_prefix_matrix_copy,\n                     int num_memset_int,\n                     int expert_alignment,\n                     void** buffer_ptrs,\n                     int** barrier_signal_ptrs,\n                     int rank,\n                     cudaStream_t stream,\n                     int num_channels) {\n#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks)        \\\n    
LAUNCH_KERNEL(&cfg,                           \\\n                  notify_dispatch<ranks>,         \\\n                  num_tokens_per_rank,            \\\n                  moe_recv_counter_mapped,        \\\n                  num_tokens_per_expert,          \\\n                  moe_recv_expert_counter_mapped, \\\n                  num_experts,                    \\\n                  num_tokens,                     \\\n                  num_channels,                   \\\n                  is_token_in_rank,               \\\n                  channel_prefix_matrix,          \\\n                  rank_prefix_matrix_copy,        \\\n                  num_memset_int,                 \\\n                  expert_alignment,               \\\n                  buffer_ptrs,                    \\\n                  barrier_signal_ptrs,            \\\n                  rank);                          \\\n    break\n\n    constexpr int kNumThreads = 128;\n    EP_HOST_ASSERT(num_experts % num_ranks == 0);\n    EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);\n\n    SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);\n    SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);\n#undef NOTIFY_DISPATCH_LAUNCH_CASE\n}\n\ntemplate <int kNumRanks>\n__global__ void cached_notify_dispatch(\n    const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {\n    // A simplified version for cached handles\n    barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);\n\n    // Copy and clean\n    auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);\n    auto ptr = static_cast<int*>(buffer_ptrs[rank]);\n    #pragma unroll\n    for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)\n        ptr[i] = rank_prefix_matrix[i];\n    #pragma unroll\n    for (int i = thread_id; i < num_memset_int; i += num_threads)\n        ptr[kNumRanks * kNumRanks + i] = 0;\n\n  
  // Barrier after cleaning\n    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);\n}\n\nvoid cached_notify_dispatch(const int* rank_prefix_matrix,\n                            int num_memset_int,\n                            void** buffer_ptrs,\n                            int** barrier_signal_ptrs,\n                            int rank,\n                            int num_ranks,\n                            cudaStream_t stream) {\n#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks)                                                                                   \\\n    LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, rank_prefix_matrix, num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank); \\\n    break\n\n    SETUP_LAUNCH_CONFIG(1, 128, stream);\n    SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);\n#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE\n}\n\ntemplate <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>\n__global__ void __launch_bounds__(kNumThreads, 1) dispatch(int4* recv_x,\n                                                           float* recv_x_scales,\n                                                           int* recv_src_idx,\n                                                           topk_idx_t* recv_topk_idx,\n                                                           float* recv_topk_weights,\n                                                           int* recv_channel_offset,\n                                                           int* send_head,\n                                                           const int4* x,\n                                                           const float* x_scales,\n                                                           const topk_idx_t* topk_idx,\n                                                           const float* topk_weights,\n                                                           const bool* is_token_in_rank,\n                                                           const 
int* channel_prefix_matrix,\n                                                           int num_tokens,\n                                                           int num_worst_tokens,\n                                                           int hidden_int4,\n                                                           int num_topk,\n                                                           int num_experts,\n                                                           int num_scales,\n                                                           int scale_token_stride,\n                                                           int scale_hidden_stride,\n                                                           void** buffer_ptrs,\n                                                           int rank,\n                                                           int num_max_send_tokens,\n                                                           int num_recv_buffer_tokens) {\n    const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);\n    const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();\n    const bool is_sender = sm_id % 2 == 0;\n    EP_DEVICE_ASSERT(num_sms % 2 == 0);\n\n    // Several warps are responsible for a single rank\n    const auto num_threads_per_rank = kNumThreads / kNumRanks;\n    const auto num_channels = num_sms / 2;\n    const auto responsible_rank = (static_cast<int>(thread_id)) / num_threads_per_rank;\n    // Even-numbered blocks for sending, odd-numbered blocks for receiving.\n    const auto responsible_channel = sm_id / 2;\n\n    int num_experts_per_rank = num_experts / kNumRanks;\n    EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);\n    EP_DEVICE_ASSERT(num_topk <= 32);\n    EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));\n    EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));\n\n    // Calculate pointers by the specific 
layout\n    // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)\n    auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) +\n                                       kNumRanks * kNumRanks * sizeof(int));\n    int target_rank = is_sender ? rank : responsible_rank;\n    auto num_channels_total = num_channels * kNumRanks;\n    auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;\n\n    // Channel buffer metadata\n    // Senders are responsible for tails, and receivers are responsible for heads\n    // Stored on the receiver side\n    // The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t`\n    // `start_offset`: kNumChannels * kNumRanks * sizeof(int)\n    // `end_offset`: kNumChannels * kNumRanks * sizeof(int)\n    // `head_idx`: kNumChannels * kNumRanks * sizeof(int)\n    // `tail_idx`: kNumChannels * kNumRanks * sizeof(int)\n    auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);\n    auto channel_end_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);\n    auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);\n    auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);\n\n    // Channel data buffers, stored on the receiver side\n    // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)\n    // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)\n    // `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(topk_idx_t)\n    // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)\n    // `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)\n    auto channel_x_buffers = Buffer<int4>(\n        ptr, num_channels_total * num_recv_buffer_tokens * 
hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);\n    auto channel_src_idx_buffers =\n        Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);\n    auto channel_topk_idx_buffers = Buffer<topk_idx_t>(\n        ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);\n    auto channel_topk_weights_buffers =\n        Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);\n    auto channel_x_scales_buffers = Buffer<float>(\n        ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);\n\n    // TMA stuffs\n#ifndef DISABLE_SM90_FEATURES\n    extern __shared__ __align__(1024) uint8_t smem_buffer[];\n    auto half_hidden_int4 = hidden_int4 / 2;\n    auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));\n    auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;\n    auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + half_hidden_bytes);\n    uint32_t tma_phase = 0;\n    if (elect_one_sync()) {\n        mbarrier_init(tma_mbarrier, 1);\n        fence_barrier_init();\n        EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);\n    }\n    __syncwarp();\n#endif\n\n    if (is_sender) {\n        // Workers for sending\n        constexpr int num_send_warps = kNumThreads / 32;\n        constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;\n        const auto send_thread_id = thread_id;\n        const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;\n        EP_DEVICE_ASSERT(kNumRanks <= 32);\n        EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);\n\n        // Send offset by `-value - 1`, e.g. 
0 -> -1, 1 -> -2\n        // NOTES: this is for distinguishing zero tokens\n        if (send_warp_id_in_rank == 0 and elect_one_sync()) {\n            int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;\n            st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);\n            value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];\n            st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);\n        }\n        __syncwarp();\n\n        // Get tasks\n        int token_start_idx, token_end_idx;\n        get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);\n\n        // Iterate over all tokens and send by chunks\n        int cached_channel_tail_idx = 0;\n        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {\n            // Check destination queue emptiness, or wait for a buffer to be released (rare cases)\n            // NOTES: the head index received by different warps may not be the same\n            auto start_time = clock64();\n            if (elect_one_sync()) {\n                while (true) {\n                    // NOTES: we only consider the worst case, because counting the real numbers is time-consuming\n                    int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());\n                    if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)\n                        break;\n\n                    // Rare cases to loop again\n                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                        printf(\"DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\\n\", rank, responsible_channel);\n                        trap();\n                    }\n                }\n            }\n            __syncwarp();\n\n            int chunk_token_idx = 0;\n         
   while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {\n                // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send the\n                // following data\n                if (token_idx % num_send_warps_per_rank == send_warp_id_in_rank and elect_one_sync())\n                    send_head[token_idx * kNumRanks + responsible_rank] =\n                        is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;\n\n                // Skip if not selected\n                if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) {\n                    token_idx++;\n                    continue;\n                }\n\n                // Get an empty slot\n                int dst_slot_idx = (cached_channel_tail_idx++) % num_recv_buffer_tokens;\n                if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) {\n                    // Copy data\n                    auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;\n                    auto shifted_x = x + token_idx * hidden_int4;\n                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global);\n\n                    // Copy source index\n                    if (elect_one_sync())\n                        channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);\n\n                    // Copy `topk_idx` and `topk_weights` with transformed index\n                    if (lane_id < num_topk) {\n                        // Top-k index\n                        int recv_expert_begin = responsible_rank * num_experts_per_rank,\n                            recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;\n                        auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id);\n                        idx_value = 
(idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;\n                        channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;\n\n                        // Top-k weights\n                        auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);\n                        weight_value = (idx_value >= 0) ? weight_value : 0.0f;\n                        channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = weight_value;\n                    }\n\n                    // Copy `x_scales`\n                    #pragma unroll\n                    for (int i = lane_id; i < num_scales; i += 32) {\n                        auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;\n                        channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + offset);\n                    }\n                }\n\n                // Move token index\n                chunk_token_idx++, token_idx++;\n            }\n\n            // Move tail index\n            // NOTES: here all warps should share the same new tail\n            asm volatile(\"bar.sync %0, %1;\" ::\"r\"(responsible_rank), \"r\"(num_threads_per_rank));\n            if (send_warp_id_in_rank == 0 and elect_one_sync())\n                st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);\n        }\n    } else {\n        // Workers for receiving and copying into buffer\n        constexpr int num_recv_warps = kNumThreads / 32;\n        constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;\n        const auto recv_thread_id = thread_id;\n        const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;\n        const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;\n        EP_DEVICE_ASSERT(kNumRanks <= 32);\n        EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);\n\n        // Calculate offset 
first\n        auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);\n        int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;\n\n        // Receive channel offset\n        int total_offset, num_tokens_to_recv;\n        if (elect_one_sync()) {\n            while ((total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0)\n                ;\n            while ((num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0)\n                ;\n            total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;\n            if (recv_warp_id_in_rank == 0)\n                recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;\n            num_tokens_to_recv -= total_offset;\n        }\n        total_offset = __shfl_sync(0xffffffff, total_offset, 0);\n        total_offset += rank_offset;\n        num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0);\n\n        // Shared tail indices for different warps\n        __shared__ volatile int shared_channel_tail_idx[kNumRanks];\n\n        auto start_time = clock64();\n        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;\n        while (num_tokens_to_recv > 0) {\n            // NOTES: unlike the sender, the receiver must ensure that the tail indices held by different warps are the same\n            while (recv_thread_id_in_rank == 0) {\n                cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());\n\n                // Ready to copy\n                if (cached_channel_head_idx != cached_channel_tail_idx) {\n                    shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx;\n                    break;\n                }\n\n                // Timeout check\n                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                    printf(\"DeepEP timeout for dispatch receivers, rank 
%d, responsible_channel = %d, tokens remained: %d\\n\",\n                           rank,\n                           responsible_channel,\n                           num_tokens_to_recv);\n                    trap();\n                }\n            }\n\n            // Synchronize queue tail\n            asm volatile(\"bar.sync %0, %1;\" ::\"r\"(responsible_rank), \"r\"(num_threads_per_rank));\n            cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank];\n\n            // Copy data\n            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;\n            for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens; chunk_idx += num_recv_warps_per_rank) {\n                int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;\n                auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;\n                auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;\n#ifndef DISABLE_SM90_FEATURES\n                #pragma unroll\n                for (int i = 0; i < 2; ++i) {\n                    tma_store_wait<0>();\n                    if (elect_one_sync()) {\n                        tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);\n                        mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);\n                        mbarrier_wait(tma_mbarrier, tma_phase);\n                        tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);\n                    }\n                }\n                __syncwarp();\n#else\n                UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4, ld_nc_global, st_na_global);\n#endif\n            }\n\n            // Copy `src_idx`\n            #pragma unroll 4\n            for (int chunk_idx = 
cached_channel_head_idx + recv_thread_id_in_rank; chunk_idx < cached_channel_tail_idx;\n                 chunk_idx += 32 * num_recv_warps_per_rank)\n                recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] =\n                    ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);\n\n            // Copy `topk_idx` and `topk_weights`\n            #pragma unroll 4\n            for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk; idx += 32 * num_recv_warps_per_rank) {\n                int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;\n                int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;\n                auto recv_idx = static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;\n                auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;\n                recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);\n                recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);\n            }\n\n            // Copy `x_scales`\n            #pragma unroll 4\n            for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales; i += 32 * num_recv_warps_per_rank) {\n                int chunk_idx = i / num_scales, scales_idx = i % num_scales;\n                int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;\n                recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales + scales_idx] =\n                    ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx);\n            }\n\n            // Move queue\n            cached_channel_head_idx += num_recv_tokens;\n            total_offset += num_recv_tokens;\n            asm volatile(\"bar.sync %0, %1;\" ::\"r\"(responsible_rank), 
\"r\"(num_threads_per_rank));\n            if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and elect_one_sync())\n                st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);\n\n            // Exit\n            num_tokens_to_recv -= num_recv_tokens;\n        }\n    }\n\n    // Clean unused `recv_topk_idx` as -1\n    if (num_worst_tokens > 0) {\n        auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);\n        const auto num_recv_tokens = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];\n        const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads;\n        const auto clean_end = num_worst_tokens * num_topk;\n        const auto clean_stride = num_sms * kNumThreads;\n        #pragma unroll\n        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)\n            recv_topk_idx[i] = -1;\n    }\n}\n\nvoid dispatch(void* recv_x,\n              float* recv_x_scales,\n              int* recv_src_idx,\n              topk_idx_t* recv_topk_idx,\n              float* recv_topk_weights,\n              int* recv_channel_offset,\n              int* send_head,\n              const void* x,\n              const float* x_scales,\n              const topk_idx_t* topk_idx,\n              const float* topk_weights,\n              const bool* is_token_in_rank,\n              const int* channel_prefix_matrix,\n              int num_tokens,\n              int num_worst_tokens,\n              int hidden_int4,\n              int num_topk,\n              int num_experts,\n              int num_scales,\n              int scale_token_stride,\n              int scale_hidden_stride,\n              void** buffer_ptrs,\n              int rank,\n              int num_ranks,\n              cudaStream_t stream,\n              int num_sms,\n              int num_max_send_tokens,\n              int num_recv_buffer_tokens) {\n    constexpr int kNumThreads = 768;\n    constexpr int kNumTMABytesPerWarp = 
8192;\n#ifndef DISABLE_SM90_FEATURES\n    constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);\n#endif\n\n    // Make sure never OOB\n    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());\n\n#define DISPATCH_LAUNCH_CASE(ranks)                                      \\\n    {                                                                    \\\n        auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \\\n        SET_SHARED_MEMORY_FOR_TMA(kernel);                               \\\n        LAUNCH_KERNEL(&cfg,                                              \\\n                      kernel,                                            \\\n                      reinterpret_cast<int4*>(recv_x),                   \\\n                      recv_x_scales,                                     \\\n                      recv_src_idx,                                      \\\n                      recv_topk_idx,                                     \\\n                      recv_topk_weights,                                 \\\n                      recv_channel_offset,                               \\\n                      send_head,                                         \\\n                      reinterpret_cast<const int4*>(x),                  \\\n                      x_scales,                                          \\\n                      topk_idx,                                          \\\n                      topk_weights,                                      \\\n                      is_token_in_rank,                                  \\\n                      channel_prefix_matrix,                             \\\n                      num_tokens,                                        \\\n                      num_worst_tokens,                                  \\\n                      hidden_int4,                                       \\\n                      num_topk,             
                             \\\n                      num_experts,                                       \\\n                      num_scales,                                        \\\n                      scale_token_stride,                                \\\n                      scale_hidden_stride,                               \\\n                      buffer_ptrs,                                       \\\n                      rank,                                              \\\n                      num_max_send_tokens,                               \\\n                      num_recv_buffer_tokens);                           \\\n    }                                                                    \\\n    break\n\n    // Even-numbered blocks for sending, odd-numbered blocks for receiving.\n    EP_HOST_ASSERT(num_sms % 2 == 0);\n    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);\n    SWITCH_RANKS(DISPATCH_LAUNCH_CASE);\n#undef DISPATCH_LAUNCH_CASE\n}\n\ntemplate <int kNumRanks>\n__global__ void cached_notify_combine(\n    void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, int** barrier_signal_ptrs, int rank) {\n    const auto sm_id = static_cast<int>(blockIdx.x);\n    if (sm_id == 0) {\n        // Barrier before cleaning\n        barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);\n\n        // Clean\n        auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);\n        auto ptr = static_cast<int*>(buffer_ptrs[rank]);\n        #pragma unroll\n        for (int i = thread_id; i < num_memset_int; i += num_threads)\n            ptr[i] = 0;\n\n        // Barrier after cleaning\n        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);\n    } else {\n        const auto channel_id = sm_id - 1;\n        const auto thread_id = static_cast<int>(threadIdx.x);\n        const auto rank_id = thread_id / 32;\n        const auto lane_id = thread_id % 32;\n        if 
(rank_id >= kNumRanks)\n            return;\n\n        int token_start_idx, token_end_idx;\n        get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);\n\n        // NOTES: `1 << 25` is a heuristic large number\n        int last_head = 1 << 25;\n        #pragma unroll\n        for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) {\n            int token_idx = token_idx_tail - lane_id, expected_head = 0;\n            auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;\n            for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++i) {\n                const int head = __shfl_sync(0xffffffff, current_head, i);\n                if (head < 0) {\n                    if (lane_id == i)\n                        expected_head = -last_head - 1;\n                } else {\n                    last_head = head;\n                }\n            }\n            if (current_head < 0 and token_idx >= token_start_idx)\n                send_head[token_idx * kNumRanks + rank_id] = expected_head;\n        }\n    }\n}\n\nvoid cached_notify_combine(void** buffer_ptrs,\n                           int* send_head,\n                           int num_channels,\n                           int num_recv_tokens,\n                           int num_memset_int,\n                           int** barrier_signal_ptrs,\n                           int rank,\n                           int num_ranks,\n                           cudaStream_t stream) {\n#define CACHED_NOTIFY_COMBINE(ranks)            \\\n    LAUNCH_KERNEL(&cfg,                         \\\n                  cached_notify_combine<ranks>, \\\n                  buffer_ptrs,                  \\\n                  send_head,                    \\\n                  num_channels,                 \\\n                  num_recv_tokens,              \\\n                  
num_memset_int,               \\\n                  barrier_signal_ptrs,          \\\n                  rank);                        \\\n    break\n\n    const int num_threads = std::max(128, 32 * num_ranks);\n    EP_HOST_ASSERT(num_ranks <= num_threads);\n    EP_HOST_ASSERT(num_threads <= 1024);\n    EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);\n    SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);\n    SWITCH_RANKS(CACHED_NOTIFY_COMBINE);\n#undef CACHED_NOTIFY_COMBINE\n}\n\ntemplate <typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>\n__global__ void __launch_bounds__(kNumThreads, 1) combine(dtype_t* recv_x,\n                                                          float* recv_topk_weights,\n                                                          const dtype_t* x,\n                                                          const float* topk_weights,\n                                                          const dtype_t* bias_0,\n                                                          const dtype_t* bias_1,\n                                                          const int* src_idx,\n                                                          const int* rank_prefix_matrix,\n                                                          const int* channel_prefix_matrix,\n                                                          int* send_head,\n                                                          int num_tokens,\n                                                          int num_recv_tokens,\n                                                          int hidden,\n                                                          int num_topk,\n                                                          void** buffer_ptrs,\n                                                          int rank,\n                                                          int num_max_send_tokens,\n                                                         
 int num_recv_buffer_tokens) {\n    const auto num_sms = static_cast<int>(gridDim.x);\n    const auto thread_id = static_cast<int>(threadIdx.x);\n    const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();\n    const auto num_channels = num_sms / 2;\n    const bool is_sender = sm_id % 2 == 0;\n    const int responsible_channel = sm_id / 2;\n    EP_DEVICE_ASSERT(num_topk <= 32);\n\n    constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);\n    int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);\n    int hidden_int4_aligned = align_down(hidden_int4, 32);\n    auto x_int4 = reinterpret_cast<const int4*>(x);\n    auto bias_0_int4 = reinterpret_cast<const int4*>(bias_0);\n    auto bias_1_int4 = reinterpret_cast<const int4*>(bias_1);\n    auto recv_int4 = reinterpret_cast<int4*>(recv_x);\n\n    // TMA stuffs\n#ifndef DISABLE_SM90_FEATURES\n    extern __shared__ __align__(1024) uint8_t smem_buffer[];\n    auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;\n#endif\n\n    if (is_sender) {\n        // Workers for sending\n        // Several warps are responsible for a single rank\n        constexpr int num_send_warps_per_rank = (kNumThreads / 32) / kNumRanks;\n        constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks;\n        const auto num_threads_per_rank = num_send_warps_per_rank * 32;\n        const auto send_thread_id = thread_id;\n        const auto send_warp_id = send_thread_id / 32;\n        const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks;\n        const auto send_warp_id_in_rank = send_warp_id / kNumRanks;\n        EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, \"Invalid warp count\");\n\n        // Calculate pointers by the specific layout\n        auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[send_rank_id]));\n        auto num_channels_total = num_channels * kNumRanks;\n        auto channel_rank_offset = responsible_channel * kNumRanks + 
rank;\n\n        // Channel meta data\n        // `head_idx`: kNumChannels * kNumRanks * sizeof(int)\n        // `tail_idx`: kNumChannels * kNumRanks * sizeof(int)\n        // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)\n        // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)\n        // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)\n        auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);\n        auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);\n        auto channel_x_buffers = Buffer<int4>(\n            ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);\n        auto channel_src_idx_buffers =\n            Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);\n        auto channel_topk_weights_buffers = Buffer<float>(\n            ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);\n\n        // Get tasks\n        // NOTES: `channel_offset` is already shifted\n        int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;\n        int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;\n        int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];\n        int num_channel_tokens =\n            (responsible_channel == num_channels - 1 ? 
num_rank_tokens\n                                                     : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) -\n            channel_offset;\n        int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens;\n\n        // Iterate over all tokens and send by chunks\n        int current_channel_tail_idx = 0;\n        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {\n            // Check destination queue emptiness, or wait for a buffer to be released (rare cases)\n            auto start_time = clock64();\n            int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));\n            if (elect_one_sync()) {\n                while (true) {\n                    // NOTES: we only consider the worst case, because counting the real numbers is time-consuming\n                    int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());\n                    if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)\n                        break;\n\n                    // Rare cases to loop again\n                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {\n                        printf(\"DeepEP timeout for combine senders, rank %d, responsible_channel = %d\\n\", rank, responsible_channel);\n                        trap();\n                    }\n                }\n            }\n            __syncwarp();\n\n            // Send by chunk\n            #pragma unroll\n            for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {\n                // Get an empty slot\n                int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;\n\n                // Copy data\n                auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;\n                auto shifted_x = x_int4 + 
(token_idx + i) * hidden_int4;\n                UNROLLED_WARP_COPY(4, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);\n\n                // Send source index\n                if (elect_one_sync())\n                    channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);\n\n                // Send `topk_weights`\n                if (num_topk > 0 and lane_id < num_topk)\n                    channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] =\n                        __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);\n            }\n            token_idx += num_round_tokens;\n            current_channel_tail_idx += num_round_tokens;\n\n            // Move tail index\n            asm volatile(\"bar.sync %0, %1;\" ::\"r\"(send_rank_id), \"r\"(num_threads_per_rank));\n            if (send_warp_id_in_rank == 0 and elect_one_sync())\n                st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);\n        }\n    } else {\n        // Workers for receiving\n        // One warp for moving the queue head, others for reduction\n        constexpr int num_recv_warps = kNumThreads / 32;\n        const auto recv_warp_id = thread_id / 32;\n        EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);\n        EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);\n\n        // Shared head, tail and retired flags for receiver warps\n        __shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];\n        __shared__ volatile int channel_tail_idx[kNumRanks];\n        __shared__ volatile bool warp_retired[num_recv_warps];\n        if (thread_id < num_recv_warps)\n            warp_retired[thread_id] = false;\n        if (lane_id < kNumRanks)\n            warp_channel_head_idx[recv_warp_id][lane_id] = 0;\n        if (thread_id < kNumRanks)\n            channel_tail_idx[thread_id] = 0;\n        asm volatile(\"bar.sync 0, %0;\" ::\"r\"(kNumThreads));\n\n        if 
(thread_id < 32) {\n            int* channel_head_idx_ptr = static_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;\n            int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;\n\n            // Queue head updater\n            int last_head = 0;\n            while (lane_id < kNumRanks) {\n                // Check retired\n                bool retired = true;\n                #pragma unroll\n                for (int i = 1; i < num_recv_warps; ++i)\n                    retired = retired and warp_retired[i];\n                if (retired)\n                    break;\n\n                // Update queue tail\n                channel_tail_idx[lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);\n\n                // Update minimum head\n                int min_head = std::numeric_limits<int>::max();\n                #pragma unroll\n                for (int i = 1; i < num_recv_warps; ++i)\n                    if (not warp_retired[i])\n                        min_head = min(min_head, warp_channel_head_idx[i][lane_id]);\n                if (min_head != std::numeric_limits<int>::max() and min_head > last_head)\n                    st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);\n            }\n        } else {\n            // Receivers\n            // Channel metadata\n            // All lanes will use data buffer, but only rank lane will use `head/tail/src_idx`\n            Buffer<int4> channel_x_buffers[kNumRanks];\n            Buffer<float> channel_topk_weights_buffers[kNumRanks];\n\n            // Calculate pointers by the specific layout\n            #pragma unroll\n            for (int i = 0; i < kNumRanks; ++i) {\n                auto channel_rank_offset = responsible_channel * kNumRanks + i;\n                auto num_channels_total = num_channels * kNumRanks;\n                // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)\n                auto ptr = 
reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));\n\n                // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)\n                channel_x_buffers[i] = Buffer<int4>(ptr,\n                                                    num_channels_total * num_recv_buffer_tokens * hidden_int4,\n                                                    channel_rank_offset * num_recv_buffer_tokens * hidden_int4);\n\n                // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)\n                ptr = reinterpret_cast<void*>(static_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));\n\n                // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)\n                channel_topk_weights_buffers[i] = Buffer<float>(\n                    ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);\n            }\n\n            // The same tokens as the dispatch process\n            int token_start_idx, token_end_idx;\n            get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);\n\n            // Iterate over all tokens and combine\n            for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {\n                // Read expected head\n                int expected_head = -1;\n                if (lane_id < kNumRanks)\n                    expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);\n\n                auto start_time = clock64();\n                while (__any_sync(0xffffffff, channel_tail_idx[lane_id] <= expected_head and expected_head >= 0)) {\n                    // Timeout check\n                    if (clock64() - start_time > 
NUM_TIMEOUT_CYCLES) {\n                        printf(\"DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\\n\",\n                               rank,\n                               responsible_channel,\n                               expected_head);\n                        trap();\n                    }\n                }\n                __syncwarp();\n\n                // Broadcast current heads\n                int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];\n                #pragma unroll\n                for (int i = 0; i < kNumRanks; ++i) {\n                    auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i);\n                    if (expected_head_i >= 0) {\n                        slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;\n                        topk_ranks[num_topk_ranks++] = i;\n                    }\n                }\n\n                // Wait shared memory release\n#ifndef DISABLE_SM90_FEATURES\n                tma_store_wait<0>();\n                __syncwarp();\n#endif\n\n                // Reduce data with pipeline\n                constexpr int kNumStages = 8;\n                EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, \"Invalid count\");\n                #pragma unroll\n                for (int i = lane_id; i < hidden_int4; i += 32) {\n                    // Read bias\n                    // TODO: make it a template\n                    int4 bias_0_value_int4 =\n                        bias_0_int4 != nullptr ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);\n                    int4 bias_1_value_int4 =\n                        bias_1_int4 != nullptr ? 
__ldg(bias_1_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);\n\n                    // Read buffers\n                    int4 recv_value_int4[kNumRanks];\n                    #pragma unroll\n                    for (int j = 0; j < num_topk_ranks; ++j)\n                        recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);\n\n                    // Reduce bias\n                    float values[kDtypePerInt4];\n                    auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);\n                    auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);\n                    #pragma unroll\n                    for (int j = 0; j < kDtypePerInt4; ++j)\n                        values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);\n\n                    // Reduce all-to-all results\n                    #pragma unroll\n                    for (int j = 0; j < num_topk_ranks; ++j) {\n                        auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);\n                        #pragma unroll\n                        for (int k = 0; k < kDtypePerInt4; ++k)\n                            values[k] += static_cast<float>(recv_value_dtypes[k]);\n                    }\n\n                    // Cast back to `dtype_t`\n                    int4 out_int4;\n                    auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);\n                    #pragma unroll\n                    for (int j = 0; j < kDtypePerInt4; ++j)\n                        out_dtypes[j] = static_cast<dtype_t>(values[j]);\n\n#ifndef DISABLE_SM90_FEATURES\n                    if (i < hidden_int4_aligned) {\n                        // Wait TMA arrival\n                        tma_store_wait<kNumStages - 1>();\n                        __syncwarp();\n\n                        // Write into TMA buffer\n                     
   auto tma_stage_idx = (i / 32) % kNumStages;\n                        reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;\n\n                        // Issue TMA\n                        tma_store_fence();\n                        __syncwarp();\n                        if (elect_one_sync()) {\n                            auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));\n                            tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,\n                                         recv_int4 + token_idx * hidden_int4 + i,\n                                         tma_bytes,\n                                         false);\n                        }\n                        __syncwarp();\n                    } else {\n#endif\n                        recv_int4[token_idx * hidden_int4 + i] = out_int4;\n#ifndef DISABLE_SM90_FEATURES\n                    }\n#endif\n                }\n\n                // Reduce `topk_weights`\n                if (lane_id < num_topk) {\n                    float value = 0;\n                    #pragma unroll\n                    for (int i = 0; i < num_topk_ranks; ++i)\n                        value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + lane_id);\n                    recv_topk_weights[token_idx * num_topk + lane_id] = value;\n                }\n\n                // Update head\n                if (lane_id < kNumRanks)\n                    warp_channel_head_idx[recv_warp_id][lane_id] = (expected_head < 0) ? 
-expected_head - 1 : expected_head + 1;\n            }\n\n            // Retired\n            __syncwarp();\n            if (elect_one_sync())\n                warp_retired[recv_warp_id] = true;\n        }\n    }\n}\n\nvoid combine(cudaDataType_t type,\n             void* recv_x,\n             float* recv_topk_weights,\n             const void* x,\n             const float* topk_weights,\n             const void* bias_0,\n             const void* bias_1,\n             const int* src_idx,\n             const int* rank_prefix_matrix,\n             const int* channel_prefix_matrix,\n             int* send_head,\n             int num_tokens,\n             int num_recv_tokens,\n             int hidden,\n             int num_topk,\n             void** buffer_ptrs,\n             int rank,\n             int num_ranks,\n             cudaStream_t stream,\n             int num_sms,\n             int num_max_send_tokens,\n             int num_recv_buffer_tokens) {\n    constexpr int kNumThreads = 768;\n    constexpr int kNumTMABytesPerWarp = 4096;\n#ifndef DISABLE_SM90_FEATURES\n    constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);\n#endif\n\n#define COMBINE_LAUNCH_CASE(dtype, ranks)                                      \\\n    {                                                                          \\\n        auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \\\n        SET_SHARED_MEMORY_FOR_TMA(kernel);                                     \\\n        LAUNCH_KERNEL(&cfg,                                                    \\\n                      kernel,                                                  \\\n                      reinterpret_cast<dtype*>(recv_x),                        \\\n                      recv_topk_weights,                                       \\\n                      reinterpret_cast<const dtype*>(x),                       \\\n                      topk_weights,                                            \\\n   
                   reinterpret_cast<const dtype*>(bias_0),                  \\\n                      reinterpret_cast<const dtype*>(bias_1),                  \\\n                      src_idx,                                                 \\\n                      rank_prefix_matrix,                                      \\\n                      channel_prefix_matrix,                                   \\\n                      send_head,                                               \\\n                      num_tokens,                                              \\\n                      num_recv_tokens,                                         \\\n                      hidden,                                                  \\\n                      num_topk,                                                \\\n                      buffer_ptrs,                                             \\\n                      rank,                                                    \\\n                      num_max_send_tokens,                                     \\\n                      num_recv_buffer_tokens);                                 \\\n    }                                                                          \\\n    break\n#define COMBINE_DTYPE_LAUNCH_CASE(dtype)                 \\\n    SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); \\\n    break\n\n    // Even-numbered blocks for sending, odd-numbered blocks for receiving\n    EP_HOST_ASSERT(num_sms % 2 == 0);\n    EP_HOST_ASSERT(kNumThreads >= num_ranks * 32);\n    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);\n    SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);\n#undef COMBINE_DTYPE_LAUNCH_CASE\n#undef COMBINE_LAUNCH_CASE\n}\n\n}  // namespace intranode\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/launch.cuh",
    "content": "#pragma once\n\n#include \"configs.cuh\"\n#include \"exception.cuh\"\n\n#ifndef SETUP_LAUNCH_CONFIG\n#ifndef DISABLE_SM90_FEATURES\n#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream)                       \\\n    cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \\\n    cudaLaunchAttribute attr[2];                                                \\\n    attr[0].id = cudaLaunchAttributeCooperative;                                \\\n    attr[0].val.cooperative = 1;                                                \\\n    attr[1].id = cudaLaunchAttributeClusterDimension;                           \\\n    attr[1].val.clusterDim.x = (num_sms % 2 == 0 ? 2 : 1);                      \\\n    attr[1].val.clusterDim.y = 1;                                               \\\n    attr[1].val.clusterDim.z = 1;                                               \\\n    cfg.attrs = attr;                                                           \\\n    cfg.numAttrs = 2\n#else\n#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \\\n    int __num_sms = (sms);                        \\\n    int __num_threads = (threads);                \\\n    auto __stream = (stream)\n#endif\n#endif\n\n#ifndef LAUNCH_KERNEL\n#ifndef DISABLE_SM90_FEATURES\n#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))\n#else\n#define LAUNCH_KERNEL(config, kernel, ...)                                                 
\\\n    do {                                                                                   \\\n        kernel<<<__num_sms, __num_threads, 0, __stream>>>(__VA_ARGS__);                    \\\n        cudaError_t e = cudaGetLastError();                                                \\\n        if (e != cudaSuccess) {                                                            \\\n            EPException cuda_exception(\"CUDA\", __FILE__, __LINE__, cudaGetErrorString(e)); \\\n            fprintf(stderr, \"%s\\n\", cuda_exception.what());                                \\\n            throw cuda_exception;                                                          \\\n        }                                                                                  \\\n    } while (0)\n#endif\n#endif\n\n#ifndef SET_SHARED_MEMORY_FOR_TMA\n#ifndef DISABLE_SM90_FEATURES\n#define SET_SHARED_MEMORY_FOR_TMA(kernel)                                                                                \\\n    EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \\\n    cfg.dynamicSmemBytes = smem_size;\n#else\n#define SET_SHARED_MEMORY_FOR_TMA(kernel) void()\n#endif\n#endif\n\n#define SWITCH_RANKS(case_macro)                           \\\n    switch (num_ranks) {                                   \\\n        case 2:                                            \\\n            case_macro(2);                                 \\\n        case 4:                                            \\\n            case_macro(4);                                 \\\n        case 8:                                            \\\n            case_macro(8);                                 \\\n        default:                                           \\\n            EP_HOST_ASSERT(false and \"Unsupported ranks\"); \\\n    }                                                      \\\n    while (false)\n\n#define SWITCH_RDMA_RANKS(case_macro)             
              \\\n    switch (num_ranks / NUM_MAX_NVL_PEERS) {                    \\\n        case 2:                                                 \\\n            case_macro(2);                                      \\\n        case 3:                                                 \\\n            case_macro(3);                                      \\\n        case 4:                                                 \\\n            case_macro(4);                                      \\\n        case 6:                                                 \\\n            case_macro(6);                                      \\\n        case 8:                                                 \\\n            case_macro(8);                                      \\\n        case 12:                                                \\\n            case_macro(12);                                     \\\n        case 16:                                                \\\n            case_macro(16);                                     \\\n        case 18:                                                \\\n            case_macro(18);                                     \\\n        case 20:                                                \\\n            case_macro(20);                                     \\\n        default:                                                \\\n            EP_HOST_ASSERT(false and \"Unsupported RDMA ranks\"); \\\n    }                                                           \\\n    while (false)\n\n#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro)         \\\n    switch (num_ranks) {                                   \\\n        case 2:                                            \\\n            case_macro(dtype, 2);                          \\\n        case 4:                                            \\\n            case_macro(dtype, 4);                          \\\n        case 8:                                            \\\n            
case_macro(dtype, 8);                          \\\n        default:                                           \\\n            EP_HOST_ASSERT(false and \"Unsupported ranks\"); \\\n    }                                                      \\\n    while (false)\n\n#define SWITCH_TYPES(case_macro)                          \\\n    switch (type) {                                       \\\n        case CUDA_R_16BF:                                 \\\n            case_macro(nv_bfloat16);                      \\\n        default:                                          \\\n            EP_HOST_ASSERT(false and \"Unsupported type\"); \\\n    }                                                     \\\n    while (false)\n\n#define SWITCH_HIDDEN(case_macro)                           \\\n    switch (hidden) {                                       \\\n        case 2048:                                          \\\n            case_macro(2048);                               \\\n        case 2560:                                          \\\n            case_macro(2560);                               \\\n        case 3072:                                          \\\n            case_macro(3072); /* for gpt-oss */             \\\n        case 4096:                                          \\\n            case_macro(4096);                               \\\n        case 5120:                                          \\\n            case_macro(5120);                               \\\n        case 6144:                                          \\\n            case_macro(6144); /* For qwen3 coder */         \\\n        case 7168:                                          \\\n            case_macro(7168);                               \\\n        case 8192:                                          \\\n            case_macro(8192);                               \\\n        default:                                            \\\n            EP_HOST_ASSERT(false and \"Unsupported hidden\"); 
\\\n    }                                                       \\\n    while (false)\n"
  },
  {
    "path": "csrc/kernels/layout.cu",
    "content": "#include \"configs.cuh\"\n#include \"exception.cuh\"\n#include \"launch.cuh\"\n\nnamespace deep_ep {\n\nnamespace layout {\n\ntemplate <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>\n__global__ void get_dispatch_layout(const topk_idx_t* topk_idx,\n                                    int* num_tokens_per_rank,\n                                    int* num_tokens_per_rdma_rank,\n                                    int* num_tokens_per_expert,\n                                    bool* is_token_in_rank,\n                                    int num_tokens,\n                                    int num_topk,\n                                    int num_ranks,\n                                    int num_experts) {\n    auto sm_id = static_cast<int>(blockIdx.x);\n    auto thread_id = static_cast<int>(threadIdx.x);\n\n    // Count expert statistics\n    __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];\n    int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);\n    if (expert_begin_idx < expert_end_idx) {\n        // Per-thread count\n        #pragma unroll\n        for (int i = 0; i < kNumExpertsPerSM; ++i)\n            num_tokens_per_expert_per_thread[thread_id][i] = 0;\n        #pragma unroll\n        for (int i = thread_id; i < num_tokens; i += kNumThreads) {\n            auto shifted_topk_idx = topk_idx + i * num_topk;\n            #pragma unroll\n            for (int j = 0, expert_idx; j < num_topk; ++j) {\n                expert_idx = static_cast<int>(shifted_topk_idx[j]);\n                if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)\n                    ++num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];\n            }\n        }\n        __syncthreads();\n\n        // Sum up\n        EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, \"Too many experts per SM\");\n        if (expert_begin_idx + 
thread_id < expert_end_idx) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < kNumThreads; ++i)\n                sum += num_tokens_per_expert_per_thread[i][thread_id];\n            num_tokens_per_expert[expert_begin_idx + thread_id] = sum;\n        }\n        return;\n    }\n\n    if (num_tokens_per_rdma_rank != nullptr)\n        EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);\n\n    // Count rank statistics\n    constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;\n    __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];\n    __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];\n    auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;\n    int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);\n    int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;\n    if (rank_begin_idx < rank_end_idx) {\n        const auto num_expert_per_rank = num_experts / num_ranks;\n        auto expert_begin = rank_begin_idx * num_expert_per_rank;\n        auto expert_end = rank_end_idx * num_expert_per_rank;\n\n        // Per-thread count\n        #pragma unroll\n        for (int i = 0; i < kNumRanksPerSM; ++i)\n            num_tokens_per_rank_per_thread[thread_id][i] = 0;\n        #pragma unroll\n        for (int i = 0; i < kNumRDMARanksPerSM; ++i)\n            num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;\n        #pragma unroll\n        for (int i = thread_id; i < num_tokens; i += kNumThreads) {\n            auto shifted_topk_idx = topk_idx + i * num_topk;\n            int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};\n            #pragma unroll\n            for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {\n                expert_idx = 
static_cast<int>(shifted_topk_idx[j]);\n                if (expert_begin <= expert_idx and expert_idx < expert_end) {\n                    // Count single rank\n                    rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;\n                    is_in_rank[rank_idx]++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS]++;\n                }\n            }\n\n            auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;\n            #pragma unroll\n            for (int j = 0; j + rank_begin_idx < rank_end_idx; ++j) {\n                shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);\n                num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);\n            }\n\n            #pragma unroll\n            for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++j)\n                num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);\n        }\n        __syncthreads();\n\n        // Sum up\n        EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, \"Too many ranks per SM\");\n        if (rank_begin_idx + thread_id < rank_end_idx) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < kNumThreads; ++i)\n                sum += num_tokens_per_rank_per_thread[i][thread_id];\n            num_tokens_per_rank[rank_begin_idx + thread_id] = sum;\n        }\n\n        if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {\n            int sum = 0;\n            #pragma unroll\n            for (int i = 0; i < kNumThreads; ++i)\n                sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];\n            num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;\n        }\n    }\n}\n\nvoid get_dispatch_layout(const topk_idx_t* topk_idx,\n                         int* num_tokens_per_rank,\n                         int* num_tokens_per_rdma_rank,\n                         int* 
num_tokens_per_expert,\n                         bool* is_token_in_rank,\n                         int num_tokens,\n                         int num_topk,\n                         int num_ranks,\n                         int num_experts,\n                         cudaStream_t stream) {\n    constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8;\n    int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;\n    EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, \"Invalid number of ranks per SM\");\n\n    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);\n    LAUNCH_KERNEL(&cfg,\n                  (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),\n                  topk_idx,\n                  num_tokens_per_rank,\n                  num_tokens_per_rdma_rank,\n                  num_tokens_per_expert,\n                  is_token_in_rank,\n                  num_tokens,\n                  num_topk,\n                  num_ranks,\n                  num_experts);\n}\n\n}  // namespace layout\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/runtime.cu",
    "content": "#include <cstring>\n#include <vector>\n\n#include \"configs.cuh\"\n#include \"exception.cuh\"\n#include \"launch.cuh\"\n#include \"utils.cuh\"\n\n#ifndef DISABLE_NVSHMEM\n#include \"ibgda_device.cuh\"\n#include \"nvshmem.h\"\n#endif\n\nnamespace deep_ep {\n\nnamespace intranode {\n\ntemplate <int kNumRanks>\n__global__ void barrier(int** barrier_signal_ptrs, int rank) {\n    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);\n}\n\nvoid barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) {\n#define BARRIER_LAUNCH_CASE(ranks)                                  \\\n    LAUNCH_KERNEL(&cfg, barrier<ranks>, barrier_signal_ptrs, rank); \\\n    break\n\n    SETUP_LAUNCH_CONFIG(1, 32, stream);\n    SWITCH_RANKS(BARRIER_LAUNCH_CASE);\n#undef BARRIER_LAUNCH_CASE\n}\n\n}  // namespace intranode\n\nnamespace internode {\n\n#ifndef DISABLE_NVSHMEM\nnvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;\nnvshmem_team_config_t cpu_rdma_team_config;\n\nstd::vector<uint8_t> get_unique_id() {\n    nvshmemx_uniqueid_t unique_id;\n    nvshmemx_get_uniqueid(&unique_id);\n    std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));\n    std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));\n    return result;\n}\n\nint init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {\n    nvshmemx_uniqueid_t root_unique_id;\n    nvshmemx_init_attr_t attr;\n    std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));\n    nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);\n    nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);\n\n    // Create sub-RDMA teams\n    // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used\n    if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {\n        EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);\n        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);\n       
 EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD,\n                                                  rank % NUM_MAX_NVL_PEERS,\n                                                  NUM_MAX_NVL_PEERS,\n                                                  num_ranks / NUM_MAX_NVL_PEERS,\n                                                  &cpu_rdma_team_config,\n                                                  0,\n                                                  &cpu_rdma_team) == 0);\n        EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);\n    }\n\n    nvshmem_barrier_all();\n    return nvshmem_my_pe();\n}\n\nvoid* alloc(size_t size, size_t alignment) {\n    return nvshmem_align(alignment, size);\n}\n\nvoid free(void* ptr) {\n    nvshmem_free(ptr);\n}\n\nvoid barrier() {\n    nvshmem_barrier_all();\n}\n\nvoid finalize() {\n    if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {\n        nvshmem_team_destroy(cpu_rdma_team);\n        cpu_rdma_team = NVSHMEM_TEAM_INVALID;\n    }\n    nvshmem_finalize();\n}\n#endif\n\n}  // namespace internode\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "csrc/kernels/utils.cuh",
    "content": "#pragma once\n\n#include \"exception.cuh\"\n\n#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC)                                                     \\\n    {                                                                                                                                 \\\n        constexpr int kLoopStride = 32 * (UNROLL_FACTOR);                                                                             \\\n        typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)];                          \\\n        auto __src = (SRC);                                                                                                           \\\n        auto __dst = (DST);                                                                                                           \\\n        for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) {                                      \\\n            _Pragma(\"unroll\") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \\\n            _Pragma(\"unroll\") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]);  \\\n        }                                                                                                                             \\\n        {                                                                                                                             \\\n            int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID);                                                                  \\\n            _Pragma(\"unroll\") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) {                                                       \\\n                if (__i + __j * 32 < (N)) {                                                                                           
\\\n                    unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32);                                                           \\\n                }                                                                                                                     \\\n            }                                                                                                                         \\\n            _Pragma(\"unroll\") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) {                                                       \\\n                if (__i + __j * 32 < (N)) {                                                                                           \\\n                    ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]);                                                            \\\n                }                                                                                                                     \\\n            }                                                                                                                         \\\n        }                                                                                                                             \\\n    }\n\nnamespace deep_ep {\n\ntemplate <int kBytes>\nstruct VecInt {};\ntemplate <>\nstruct VecInt<1> {\n    using vec_t = int8_t;\n};\ntemplate <>\nstruct VecInt<2> {\n    using vec_t = int16_t;\n};\ntemplate <>\nstruct VecInt<4> {\n    using vec_t = int;\n};\ntemplate <>\nstruct VecInt<8> {\n    using vec_t = int64_t;\n};\ntemplate <>\nstruct VecInt<16> {\n    using vec_t = int4;\n};\n\ntemplate <typename FuncT>\nstruct PatternVisitor {\n    FuncT func;\n\n    __device__ __host__ explicit PatternVisitor(FuncT&& func) : func(std::forward<FuncT>(func)) {}\n\n    __device__ __host__ auto operator[](const uint32_t& i) { return func(i); }\n};\n\n__device__ __forceinline__ void trap() {\n    asm(\"trap;\");\n}\n\n__device__ __forceinline__ 
void memory_fence() {\n    asm volatile(\"fence.acq_rel.sys;\" ::: \"memory\");\n}\n\n__device__ __forceinline__ void memory_fence_gpu() {\n    asm volatile(\"fence.acq_rel.gpu;\" ::: \"memory\");\n}\n\n__device__ __forceinline__ void memory_fence_cta() {\n    asm volatile(\"fence.acq_rel.cta;\" ::: \"memory\");\n}\n\n__device__ __forceinline__ void st_relaxed_sys_global(const int* ptr, int val) {\n    asm volatile(\"st.relaxed.sys.global.s32 [%0], %1;\" ::\"l\"(ptr), \"r\"(val) : \"memory\");\n}\n\n__device__ __forceinline__ void st_release_sys_global(const int* ptr, int val) {\n    asm volatile(\"st.release.sys.global.s32 [%0], %1;\" ::\"l\"(ptr), \"r\"(val) : \"memory\");\n}\n\n__device__ __forceinline__ void st_release_cta(const int* ptr, int val) {\n    asm volatile(\"st.release.cta.s32 [%0], %1;\" ::\"l\"(ptr), \"r\"(val) : \"memory\");\n}\n\n__device__ __forceinline__ int ld_acquire_sys_global(const int* ptr) {\n    int ret;\n    asm volatile(\"ld.acquire.sys.global.s32 %0, [%1];\" : \"=r\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t* ptr) {\n    uint64_t ret;\n    asm volatile(\"ld.acquire.sys.global.u64 %0, [%1];\" : \"=l\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ int ld_acquire_global(const int* ptr) {\n    int ret;\n    asm volatile(\"ld.acquire.gpu.global.s32 %0, [%1];\" : \"=r\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {\n    int ret;\n    asm volatile(\"atom.add.release.sys.global.s32 %0, [%1], %2;\" : \"=r\"(ret) : \"l\"(ptr), \"r\"(value));\n    return ret;\n}\n\n__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {\n    int ret;\n    asm volatile(\"atom.add.release.gpu.global.s32 %0, [%1], %2;\" : \"=r\"(ret) : \"l\"(ptr), \"r\"(value));\n    return ret;\n}\n\n__device__ __forceinline__ int ld_acquire_cta(const int* ptr) {\n    
int ret;\n    asm volatile(\"ld.acquire.cta.s32 %0, [%1];\" : \"=r\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t* ptr) {\n    uint16_t ret;\n    asm volatile(\"ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];\" : \"=h\"(ret) : \"l\"(ptr));\n    return static_cast<uint8_t>(ret);\n}\n\n__device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t* ptr) {\n    uint16_t ret;\n    asm volatile(\"ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];\" : \"=h\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t* ptr) {\n    uint32_t ret;\n    asm volatile(\"ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];\" : \"=r\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t* ptr) {\n    uint64_t ret;\n    asm volatile(\"ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];\" : \"=l\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ int ld_volatile_global(const int* ptr) {\n    int ret;\n    asm volatile(\"ld.volatile.global.s32 %0, [%1];\" : \"=r\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ float ld_volatile_global(const float* ptr) {\n    float ret;\n    asm volatile(\"ld.volatile.global.f32 %0, [%1];\" : \"=f\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ int64_t ld_volatile_global(const int64_t* ptr) {\n    int64_t ret;\n    asm volatile(\"ld.volatile.global.s64 %0, [%1];\" : \"=l\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ int64_t ld_volatile_global(const uint64_t* ptr) {\n    int64_t ret;\n    asm volatile(\"ld.volatile.global.u64 %0, [%1];\" : \"=l\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\n#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS\n#define LD_NC_FUNC \"ld.global.nc.L1::no_allocate.L2::256B\"\n#else\n#define LD_NC_FUNC \"ld.volatile.global\"\n#endif\n\n// `ld.global.nc.L1::no_allocate` will be 
translated into `LDG.E.NA.[width].CONSTANT` in SASS\ntemplate <typename dtype_t>\n__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t* ptr) {\n    auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));\n    return *reinterpret_cast<dtype_t*>(&ret);\n}\n\ntemplate <>\n__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t* ptr) {\n    uint16_t ret;\n    // NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)\n    asm volatile(LD_NC_FUNC \".u8 %0, [%1];\" : \"=h\"(ret) : \"l\"(ptr));\n    return static_cast<uint8_t>(ret);\n}\n\ntemplate <>\n__device__ __forceinline__ int ld_nc_global(const int* ptr) {\n    int ret;\n    asm volatile(LD_NC_FUNC \".s32 %0, [%1];\" : \"=r\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\ntemplate <>\n__device__ __forceinline__ int64_t ld_nc_global(const int64_t* ptr) {\n    int64_t ret;\n    asm volatile(LD_NC_FUNC \".s64 %0, [%1];\" : \"=l\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\ntemplate <>\n__device__ __forceinline__ float ld_nc_global(const float* ptr) {\n    float ret;\n    asm volatile(LD_NC_FUNC \".f32 %0, [%1];\" : \"=f\"(ret) : \"l\"(ptr));\n    return ret;\n}\n\ntemplate <>\n__device__ __forceinline__ int2 ld_nc_global(const int2* ptr) {\n    int2 ret;\n    asm volatile(LD_NC_FUNC \".v2.s32 {%0, %1}, [%2];\" : \"=r\"(ret.x), \"=r\"(ret.y) : \"l\"(ptr));\n    return ret;\n}\n\ntemplate <>\n__device__ __forceinline__ int4 ld_nc_global(const int4* ptr) {\n    int4 ret;\n    asm volatile(LD_NC_FUNC \".v4.s32 {%0, %1, %2, %3}, [%4];\" : \"=r\"(ret.x), \"=r\"(ret.y), \"=r\"(ret.z), \"=r\"(ret.w) : \"l\"(ptr));\n    return ret;\n}\n\n__device__ __forceinline__ void st_na_relaxed(const uint8_t* ptr, uint8_t val) {\n    asm volatile(\"st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;\" : : \"l\"(ptr), \"h\"(static_cast<uint16_t>(val)));\n}\n\n__device__ __forceinline__ void st_na_relaxed(const uint16_t* ptr, 
uint16_t val) {\n    asm volatile(\"st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;\" : : \"l\"(ptr), \"h\"(val));\n}\n\n__device__ __forceinline__ void st_na_relaxed(const uint32_t* ptr, uint32_t val) {\n    asm volatile(\"st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;\" : : \"l\"(ptr), \"r\"(val));\n}\n\n__device__ __forceinline__ void st_na_relaxed(const int* ptr, int val) {\n    asm volatile(\"st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;\" : : \"l\"(ptr), \"r\"(val));\n}\n\n__device__ __forceinline__ void st_na_relaxed(const int4* ptr, int4 val) {\n    asm volatile(\"st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};\"\n                 :\n                 : \"l\"(ptr), \"r\"(val.x), \"r\"(val.y), \"r\"(val.z), \"r\"(val.w));\n}\n\n__device__ __forceinline__ void st_na_release(const int* ptr, int val) {\n    asm volatile(\"st.release.gpu.global.L1::no_allocate.b32 [%0], %1;\" : : \"l\"(ptr), \"r\"(val));\n}\n\n__device__ __forceinline__ void st_na_release(const uint32_t* ptr, uint32_t val) {\n    asm volatile(\"st.release.gpu.global.L1::no_allocate.b32 [%0], %1;\" : : \"l\"(ptr), \"r\"(val));\n}\n\n__device__ __forceinline__ void st_na_release(const uint64_t* ptr, uint64_t val) {\n    asm volatile(\"st.release.gpu.global.L1::no_allocate.b64 [%0], %1;\" : : \"l\"(ptr), \"l\"(val));\n}\n\n// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS\n#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS\n#define ST_NA_FUNC \"st.global.L1::no_allocate\"\n#else\n#define ST_NA_FUNC \"st.global\"\n#endif\n\ntemplate <typename dtype_t>\n__device__ __forceinline__ void st_na_global(const dtype_t* ptr, const dtype_t& value) {\n    st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),\n                 *reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));\n}\n\ntemplate <>\n__device__ __forceinline__ void st_na_global(const int* ptr, const int& value) {\n    asm volatile(ST_NA_FUNC 
\".s32 [%0], %1;\" ::\"l\"(ptr), \"r\"(value));\n}\n\ntemplate <>\n__device__ __forceinline__ void st_na_global(const int64_t* ptr, const int64_t& value) {\n    asm volatile(ST_NA_FUNC \".s64 [%0], %1;\" ::\"l\"(ptr), \"l\"(value));\n}\n\ntemplate <>\n__device__ __forceinline__ void st_na_global(const float* ptr, const float& value) {\n    asm volatile(ST_NA_FUNC \".f32 [%0], %1;\" ::\"l\"(ptr), \"f\"(value));\n}\n\ntemplate <>\n__device__ __forceinline__ void st_na_global(const int4* ptr, const int4& value) {\n    asm volatile(ST_NA_FUNC \".v4.s32 [%0], {%1, %2, %3, %4};\" ::\"l\"(ptr), \"r\"(value.x), \"r\"(value.y), \"r\"(value.z), \"r\"(value.w));\n}\n\n__device__ __forceinline__ float log2f_approx(const float& x) {\n    float ret;\n    asm volatile(\"lg2.approx.f32 %0, %1;\" : \"=f\"(ret) : \"f\"(x));\n    return ret;\n}\n\n__device__ __forceinline__ float exp2f_approx(const float& x) {\n    float ret;\n    asm volatile(\"ex2.approx.f32 %0, %1;\" : \"=f\"(ret) : \"f\"(x));\n    return ret;\n}\n\n__forceinline__ __device__ int get_lane_id() {\n    int lane_id;\n    asm(\"mov.s32 %0, %laneid;\" : \"=r\"(lane_id));\n    return lane_id;\n}\n\n__device__ __forceinline__ uint32_t elect_one_sync() {\n#ifndef DISABLE_SM90_FEATURES\n    uint32_t pred = 0;\n    asm volatile(\n        \"{\\n\"\n        \".reg .b32 %%rx;\\n\"\n        \".reg .pred %%px;\\n\"\n        \"      elect.sync %%rx|%%px, %1;\\n\"\n        \"@%%px mov.s32 %0, 1;\\n\"\n        \"}\\n\"\n        : \"+r\"(pred)\n        : \"r\"(0xffffffff));\n    return pred;\n#else\n    return get_lane_id() == 0;\n#endif\n}\n\n// TMA PTX instructions\n#ifndef DISABLE_SM90_FEATURES\n\n__device__ __forceinline__ void fence_barrier_init() {\n    asm volatile(\"fence.mbarrier_init.release.cluster; \\n\" ::);\n}\n\n__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {\n    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));\n    asm 
volatile(\"mbarrier.init.shared::cta.b64 [%1], %0;\" ::\"r\"(arrive_count), \"r\"(mbar_int_ptr));\n}\n\n__device__ __forceinline__ void mbarrier_inval(uint64_t* mbar_ptr) {\n    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));\n    asm volatile(\"mbarrier.inval.shared::cta.b64 [%0];\" ::\"r\"(mbar_int_ptr));\n}\n\ntemplate <bool kWithMultiStages = false>\n__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase, int stage_idx = 0) {\n    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));\n    const auto& wait = kWithMultiStages ? (phase >> stage_idx) & 1 : phase;\n    asm volatile(\n        \"{\\n\\t\"\n        \".reg .pred       P1; \\n\\t\"\n        \"LAB_WAIT: \\n\\t\"\n        \"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \\n\\t\"\n        \"@P1 bra DONE; \\n\\t\"\n        \"bra     LAB_WAIT; \\n\\t\"\n        \"DONE: \\n\\t\"\n        \"}\" ::\"r\"(mbar_int_ptr),\n        \"r\"(wait),\n        \"r\"(0x989680));\n    phase ^= kWithMultiStages ? 
(1 << stage_idx) : 1;\n}\n\n__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {\n    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));\n    asm volatile(\"mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \\n\\t\" ::\"r\"(num_bytes), \"r\"(mbar_int_ptr));\n}\n\n__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar_ptr) {\n    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));\n    asm volatile(\"mbarrier.arrive.shared::cta.b64 _, [%0]; \\n\\t\" ::\"r\"(mbar_int_ptr));\n}\n\n__device__ __forceinline__ void tma_store_fence() {\n    asm volatile(\"fence.proxy.async.shared::cta;\");\n}\n\nconstexpr uint64_t kEvictFirst = 0x12f0000000000000;\nconstexpr uint64_t kEvictNormal = 0x1000000000000000;\n\n__device__ __forceinline__ void tma_load_1d(\n    const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes, bool evict_first = true) {\n    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));\n    auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n    const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;\n    asm volatile(\n        \"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\\n\" ::\"r\"(smem_int_ptr),\n        \"l\"(gmem_ptr),\n        \"r\"(num_bytes),\n        \"r\"(mbar_int_ptr),\n        \"l\"(cache_hint)\n        : \"memory\");\n}\n\n__device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes, bool evict_first = true) {\n    auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n    const auto cache_hint = evict_first ? 
kEvictFirst : kEvictNormal;\n    asm volatile(\"cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\\n\" ::\"l\"(gmem_ptr),\n                 \"r\"(smem_int_ptr),\n                 \"r\"(num_bytes),\n                 \"l\"(cache_hint)\n                 : \"memory\");\n    asm volatile(\"cp.async.bulk.commit_group;\");\n}\n\ntemplate <int N>\n__device__ __forceinline__ void tma_store_wait() {\n    asm volatile(\"cp.async.bulk.wait_group.read %0;\" ::\"n\"(N) : \"memory\");\n}\n\n#endif\n\ntemplate <typename dtype_t>\n__host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {\n    return (a + b - 1) / b;\n}\n\ntemplate <typename dtype_t>\n__host__ __device__ constexpr dtype_t align_up(dtype_t a, dtype_t b) {\n    return ceil_div<dtype_t>(a, b) * b;\n}\n\ntemplate <typename dtype_t>\n__host__ __device__ constexpr dtype_t align_down(dtype_t a, dtype_t b) {\n    return a / b * b;\n}\n\n__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id, int& token_start_idx, int& token_end_idx) {\n    int num_tokens_per_sm = ceil_div(num_tokens, num_sms);\n    token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);\n    token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);\n}\n\ntemplate <typename dtype_a_t, typename dtype_b_t>\n__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {\n    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), \"Invalid dtypes\");\n    dtype_b_t packed;\n    auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);\n    unpacked_ptr[0] = x, unpacked_ptr[1] = y;\n    return packed;\n}\n\ntemplate <typename dtype_a_t, typename dtype_b_t>\n__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {\n    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), \"Invalid dtypes\");\n    auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);\n    x = unpacked_ptr[0], y = 
unpacked_ptr[1];\n}\n\ntemplate <typename dtype_t>\n__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {\n    EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, \"\");\n    auto send_int_values = reinterpret_cast<int*>(&ptr);\n    int recv_int_values[sizeof(dtype_t) / sizeof(int)];\n    #pragma unroll\n    for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++i)\n        recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);\n    return *reinterpret_cast<dtype_t*>(recv_int_values);\n}\n\nconstexpr float kFP8Margin = 1e-4;\nconstexpr float kFinfoAmaxE4M3 = 448.0f;\nconstexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;\n\n__forceinline__ __device__ float fast_pow2(int x) {\n    // We can ensure `-126 <= x and x <= 127`\n    uint32_t bits_x = (x + 127) << 23;\n    return *reinterpret_cast<float*>(&bits_x);\n}\n\n__forceinline__ __device__ int fast_log2_ceil(float x) {\n    auto bits_x = *reinterpret_cast<uint32_t*>(&x);\n    auto exp_x = (bits_x >> 23) & 0xff;\n    auto man_bits = bits_x & ((1 << 23) - 1);\n    return exp_x - 127 + (man_bits != 0);\n}\n\n__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {\n    if (round_scale) {\n        auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);\n        scale = fast_pow2(-exp_scale_inv);\n        scale_inv = fast_pow2(exp_scale_inv);\n    } else {\n        scale_inv = amax * kFinfoAmaxInvE4M3;\n        scale = kFinfoAmaxE4M3 / amax;\n    }\n}\n\ntemplate <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>\n__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {\n    if constexpr (kIsUE8M0) {\n        return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);\n    } else {\n        return value;\n    }\n}\n\ntemplate <int kNumRanks, bool kSyncOnly = false>\n__forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs, 
int rank) {\n    auto thread_id = static_cast<int>(threadIdx.x);\n\n    // For non-sync-only cases, the memory operations by other threads in the block must be visible to the `sys` scope\n    if constexpr (not kSyncOnly) {\n        memory_fence();\n        __syncthreads();\n    }\n\n    // Add self-ranks, sub other ranks\n    if (thread_id < kNumRanks) {\n        atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);\n        atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);\n    }\n    EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);\n\n    // Check timeout\n    auto start_time = clock64();\n    while (true) {\n        auto value = thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0;\n        if (__all_sync(0xffffffff, value <= 0))\n            break;\n\n        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) {\n            printf(\"DeepEP timeout check failed: rank = %d, thread = %d, value = %d\\n\", rank, thread_id, value);\n            trap();\n        }\n    }\n    __syncthreads();\n}\n\n__forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) {\n    int ret;\n    asm volatile(\"atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;\" : \"=r\"(ret) : \"l\"(addr), \"r\"(x), \"r\"(y) : \"memory\");\n    return ret;\n}\n\n__forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) {\n    int ret;\n    asm volatile(\"atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;\" : \"=r\"(ret) : \"l\"(addr), \"r\"(x) : \"memory\");\n    return ret;\n}\n\n__forceinline__ __device__ void acquire_lock(int* mutex) {\n    // To make later memory operations valid, we must use `acquire` for memory semantics\n    while (atomic_cas_cta_acquire(mutex, 0, 1) != 0)\n        ;\n}\n\n__forceinline__ __device__ void release_lock(int* mutex) {\n    // To make previous memory operations visible to other threads, we must use `release` for memory semantics\n  
  atomic_exch_cta_release(mutex, 0);\n}\n\n// Operation functors\ntemplate <typename T>\nstruct ReduceSum {\n    __device__ T operator()(T a, T b) const { return a + b; }\n};\ntemplate <typename T>\nstruct ReduceMax {\n    __device__ T operator()(T a, T b) const { return a > b ? a : b; }\n};\ntemplate <typename T>\nstruct ReduceMin {\n    __device__ T operator()(T a, T b) const { return a < b ? a : b; }\n};\ntemplate <typename T>\nstruct ReduceAnd {\n    __device__ T operator()(T a, T b) const { return a & b; }\n};\ntemplate <typename T>\nstruct ReduceOr {\n    __device__ T operator()(T a, T b) const { return a | b; }\n};\n\n// Unified reduction function\ntemplate <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>\n__forceinline__ __device__ T warp_reduce(T value, Op op) {\n    EP_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or kNumLanesPerGroup == 4 or\n                         kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,\n                     \"Invalid number of lanes\");\n    constexpr uint32_t mask = 0xffffffff;\n    if constexpr (kIntergroupReduce) {\n        if constexpr (kNumLanesPerGroup <= 1)\n            value = op(value, __shfl_xor_sync(mask, value, 1));\n        if constexpr (kNumLanesPerGroup <= 2)\n            value = op(value, __shfl_xor_sync(mask, value, 2));\n        if constexpr (kNumLanesPerGroup <= 4)\n            value = op(value, __shfl_xor_sync(mask, value, 4));\n        if constexpr (kNumLanesPerGroup <= 8)\n            value = op(value, __shfl_xor_sync(mask, value, 8));\n        if constexpr (kNumLanesPerGroup <= 16)\n            value = op(value, __shfl_xor_sync(mask, value, 16));\n    } else {\n        if constexpr (kNumLanesPerGroup >= 32)\n            value = op(value, __shfl_xor_sync(mask, value, 16));\n        if constexpr (kNumLanesPerGroup >= 16)\n            value = op(value, __shfl_xor_sync(mask, value, 8));\n        if constexpr (kNumLanesPerGroup >= 8)\n    
        value = op(value, __shfl_xor_sync(mask, value, 4));\n        if constexpr (kNumLanesPerGroup >= 4)\n            value = op(value, __shfl_xor_sync(mask, value, 2));\n        if constexpr (kNumLanesPerGroup >= 2)\n            value = op(value, __shfl_xor_sync(mask, value, 1));\n    }\n    return value;\n}\n\n// Convenience aliases\ntemplate <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>\n__forceinline__ __device__ T warp_reduce_sum(T value) {\n    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});\n}\n\ntemplate <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>\n__forceinline__ __device__ T warp_reduce_max(T value) {\n    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMax<T>{});\n}\n\ntemplate <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>\n__forceinline__ __device__ T warp_reduce_min(T value) {\n    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMin<T>{});\n}\n\ntemplate <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>\n__forceinline__ __device__ T warp_reduce_and(T value) {\n    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceAnd<T>{});\n}\n\ntemplate <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>\n__forceinline__ __device__ T warp_reduce_or(T value) {\n    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});\n}\n\n}  // namespace deep_ep\n"
  },
  {
    "path": "deep_ep/__init__.py",
    "content": "import torch\n\nfrom .utils import EventOverlap\nfrom .buffer import Buffer\n\n# noinspection PyUnresolvedReferences\nfrom deep_ep_cpp import Config, topk_idx_t\n"
  },
  {
    "path": "deep_ep/buffer.py",
    "content": "import os\nimport torch\nimport torch.distributed as dist\nfrom typing import Callable, List, Tuple, Optional, Union\n\n# noinspection PyUnresolvedReferences\nimport deep_ep_cpp\n# noinspection PyUnresolvedReferences\nfrom deep_ep_cpp import Config, EventHandle\nfrom .utils import EventOverlap, check_nvlink_connections\n\n\nclass Buffer:\n    \"\"\"\n    The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) models, which support:\n        - high-throughput intranode all-to-all (dispatch and combine, using NVLink)\n        - high-throughput internode all-to-all (dispatch and combine, using RDMA and NVLink)\n        - low-latency all-to-all (dispatch and combine, using RDMA)\n\n    Attributes:\n        num_sms: the number of SMs used in high-throughput kernels.\n        rank: the local rank number.\n        group_size: the number of ranks in the group.\n        group: the communication group.\n        num_nvl_bytes: the buffer size for intranode NVLink communication.\n        num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.\n        runtime: the C++ runtime.\n    \"\"\"\n\n    num_sms: int = 20\n\n    def __init__(self,\n                 group: Optional[dist.ProcessGroup],\n                 num_nvl_bytes: int = 0,\n                 num_rdma_bytes: int = 0,\n                 low_latency_mode: bool = False,\n                 num_qps_per_rank: int = 24,\n                 allow_nvlink_for_low_latency_mode: bool = True,\n                 allow_mnnvl: bool = False,\n                 use_fabric: bool = False,\n                 explicitly_destroy: bool = False,\n                 enable_shrink: bool = False,\n                 comm: Optional[\"mpi4py.MPI.Comm\"] = None) -> None:  # noqa: F821\n        \"\"\"\n        Initialize the communication buffer.\n\n        Arguments:\n            group: the communication group.\n            num_nvl_bytes: the buffer size for intranode NVLink 
communication.\n            num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.\n            low_latency_mode: whether to enable low-latency mode.\n            num_qps_per_rank: the number of QPs for RDMA; low-latency mode requires this number to equal\n                the number of local experts.\n            allow_nvlink_for_low_latency_mode: whether to allow NVLink traffic in low-latency mode. Note that this is\n                somewhat incompatible with hook-based overlapping.\n                Warning: PCIe connections may lead to errors due to memory ordering issues,\n                so make sure all connections are via NVLink.\n            allow_mnnvl: whether to allow multi-node NVLink (MNNVL).\n            use_fabric: whether to use the fabric API for memory buffers.\n            enable_shrink: whether to enable shrink mode. Shrink mode allocates a mask buffer to support masking ranks dynamically.\n            explicitly_destroy: if this flag is set to True, you must explicitly call `destroy()` to release resources;\n                otherwise, the resources will be released by the destructor.\n                Note: releasing resources in the destructor may cause Python's exception handling process to hang.\n            comm: the `mpi4py.MPI.Comm` communicator to use when `group` is None.\n        \"\"\"\n        check_nvlink_connections(group)\n\n        # Initialize the C++ runtime\n        if group is not None:\n            self.rank = group.rank()\n            self.group = group\n            self.group_size = group.size()\n\n            def all_gather_object(obj):\n                object_list = [None] * self.group_size\n                dist.all_gather_object(object_list, obj, group)\n                return object_list\n        elif comm is not None:\n            self.rank = comm.Get_rank()\n            self.group = comm\n            self.group_size = comm.Get_size()\n\n        
    def all_gather_object(obj):\n                return comm.allgather(obj)\n        else:\n            raise ValueError(\"Either 'group' or 'comm' must be provided.\")\n        self.num_nvl_bytes = num_nvl_bytes\n        self.num_rdma_bytes = num_rdma_bytes\n        self.low_latency_mode = low_latency_mode\n        self.explicitly_destroy = explicitly_destroy\n        self.enable_shrink = enable_shrink\n        self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy,\n                                          enable_shrink, use_fabric)\n\n        # Synchronize device IDs\n        local_device_id = self.runtime.get_local_device_id()\n        device_ids = all_gather_object(local_device_id)\n\n        # Synchronize IPC handles\n        local_ipc_handle = self.runtime.get_local_ipc_handle()\n        ipc_handles = all_gather_object(local_ipc_handle)\n\n        # Synchronize NVSHMEM unique IDs\n        root_unique_id = None\n        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:\n            # Enable IBGDA\n            assert num_qps_per_rank > 0\n            os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1'\n            os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'\n            os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'\n\n            # Make sure the QP depth is always larger than the number of in-flight WRs, so that we can skip the WQ slot check\n            self.nvshmem_qp_depth = int(os.environ.get('NVSHMEM_QP_DEPTH', '1024'))\n            os.environ['NVSHMEM_QP_DEPTH'] = str(self.nvshmem_qp_depth)\n\n            # Reduce GPU memory usage\n            # 6 default teams + 1 extra team\n            os.environ['NVSHMEM_MAX_TEAMS'] = '7'\n            # Disable NVLink SHARP (NVLS)\n            os.environ['NVSHMEM_DISABLE_NVLS'] = '1'\n            # NOTES: NVSHMEM initialization requires at least 256 MiB\n            
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'\n\n            if not allow_mnnvl:\n                # Disable multi-node NVLink detection\n                os.environ['NVSHMEM_DISABLE_MNNVL'] = '1'\n\n            # Synchronize using the root ID\n            if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):\n                root_unique_id = self.runtime.get_local_nvshmem_unique_id()\n            nvshmem_unique_ids = all_gather_object(root_unique_id)\n            root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]\n\n        # Make the C++ runtime available\n        self.runtime.sync(device_ids, ipc_handles, root_unique_id)\n        assert self.runtime.is_available()\n\n    def destroy(self):\n        \"\"\"\n        Destroy the C++ runtime and release resources.\n        \"\"\"\n\n        assert self.explicitly_destroy, '`explicitly_destroy` flag must be set'\n\n        self.runtime.destroy()\n        self.runtime = None\n\n    @staticmethod\n    def is_sm90_compiled():\n        return deep_ep_cpp.is_sm90_compiled()\n\n    @staticmethod\n    def set_num_sms(new_num_sms: int) -> None:\n        \"\"\"\n        Set the number of SMs to use in high-throughput kernels.\n\n        Arguments:\n            new_num_sms: the new number of SMs, which must be even.\n        \"\"\"\n\n        assert new_num_sms % 2 == 0, 'The SM count must be even'\n        Buffer.num_sms = new_num_sms\n\n    @staticmethod\n    def capture() -> EventOverlap:\n        \"\"\"\n        Capture a CUDA event on the current stream, i.e. 
`torch.cuda.current_stream()`.\n\n        Returns:\n            event: the captured event.\n        \"\"\"\n        return EventOverlap(EventHandle())\n\n    @staticmethod\n    def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int) -> int:\n        \"\"\"\n        Get the minimum size requirement for the RDMA buffer. The size is calculated assuming BF16 data.\n\n        Arguments:\n            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch; all ranks must hold the same value.\n            hidden: the hidden dimension of each token.\n            num_ranks: the number of EP group ranks.\n            num_experts: the total number of experts.\n\n        Returns:\n            size: the recommended RDMA buffer size.\n        \"\"\"\n        return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)\n\n    def get_comm_stream(self) -> torch.Stream:\n        \"\"\"\n        Get the communication stream.\n\n        Returns:\n            stream: the communication stream.\n        \"\"\"\n        ts: torch.Stream = self.runtime.get_comm_stream()\n        return torch.cuda.Stream(stream_id=ts.stream_id, device_index=ts.device_index, device_type=ts.device_type)\n\n    def get_local_buffer_tensor(self,\n                                dtype: torch.dtype,\n                                size: Optional[torch.Size] = None,\n                                offset: int = 0,\n                                use_rdma_buffer: bool = False) -> torch.Tensor:\n        \"\"\"\n        Get the raw buffer (slicing supported) as a PyTorch tensor.\n\n        Arguments:\n            dtype: the data type (PyTorch `dtype`) for the tensor.\n            size: the slice size (in elements) to get from the buffer.\n            offset: the offset of the beginning element.\n            use_rdma_buffer: whether to return the RDMA buffer.\n        \"\"\"\n   
     tensor = self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer)\n        if size is None:\n            return tensor\n\n        assert tensor.numel() >= size.numel()\n        return tensor[:size.numel()].view(size)\n\n    @staticmethod\n    def _unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):\n        bias_0, bias_1 = None, None\n        if isinstance(bias, torch.Tensor):\n            bias_0 = bias\n        elif isinstance(bias, tuple):\n            assert len(bias) == 2\n            bias_0, bias_1 = bias\n        return bias_0, bias_1\n\n    @staticmethod\n    def get_dispatch_config(num_ranks: int) -> Config:\n        \"\"\"\n        Get a recommended dispatch config.\n\n        Arguments:\n            num_ranks: the number of ranks.\n\n        Returns:\n            config: the recommended config.\n        \"\"\"\n\n        # TODO: automatically tune\n        config_map = {\n            2: Config(Buffer.num_sms, 24, 256, 6, 128),\n            4: Config(Buffer.num_sms, 6, 256, 6, 128),\n            8: Config(Buffer.num_sms, 6, 256, 6, 128),\n            16: Config(Buffer.num_sms, 36, 288, 20, 128),\n            24: Config(Buffer.num_sms, 32, 288, 8, 128),\n            32: Config(Buffer.num_sms, 32, 288, 8, 128),\n            48: Config(Buffer.num_sms, 32, 288, 8, 128),\n            64: Config(Buffer.num_sms, 32, 288, 8, 128),\n            96: Config(Buffer.num_sms, 20, 480, 12, 128),\n            128: Config(Buffer.num_sms, 20, 560, 12, 128),\n            144: Config(Buffer.num_sms, 32, 720, 12, 128),\n            160: Config(Buffer.num_sms, 28, 720, 12, 128),\n        }\n        assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'\n        return config_map[num_ranks]\n\n    @staticmethod\n    def get_combine_config(num_ranks: int) -> Config:\n        \"\"\"\n        Get a recommended combine config.\n\n        Arguments:\n            num_ranks: the number of ranks.\n\n        Returns:\n        
    config: the recommended config.\n        \"\"\"\n\n        # TODO: automatically tune\n        config_map = {\n            2: Config(Buffer.num_sms, 10, 256, 6, 128),\n            4: Config(Buffer.num_sms, 9, 256, 6, 128),\n            8: Config(Buffer.num_sms, 4, 256, 6, 128),\n            16: Config(Buffer.num_sms, 4, 288, 12, 128),\n            24: Config(Buffer.num_sms, 1, 288, 8, 128),\n            32: Config(Buffer.num_sms, 1, 288, 8, 128),\n            48: Config(Buffer.num_sms, 1, 288, 8, 128),\n            64: Config(Buffer.num_sms, 1, 288, 8, 128),\n            96: Config(Buffer.num_sms, 1, 480, 8, 128),\n            128: Config(Buffer.num_sms, 1, 560, 8, 128),\n            144: Config(Buffer.num_sms, 2, 720, 8, 128),\n            160: Config(Buffer.num_sms, 2, 720, 8, 128),\n        }\n        assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'\n        return config_map[num_ranks]\n\n    # noinspection PyTypeChecker\n    def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,\n                            previous_event: Optional[EventOverlap] = None, async_finish: bool = False,\n                            allocate_on_comm_stream: bool = False) -> \\\n            Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]:\n        \"\"\"\n        Calculate the layout required for later communication.\n\n        Arguments:\n            topk_idx: `[num_tokens, num_topk]`, dtype must be `deep_ep.topk_idx_t` (typically `torch.int64`), the expert\n                indices selected by each token; `-1` means no selection.\n            num_experts: the number of experts.\n            previous_event: the event to wait for before executing the kernel.\n            async_finish: if set, the current stream will not wait for the communication kernels to finish.\n            allocate_on_comm_stream: whether all the allocated tensors are owned by the communication 
stream.\n\n        Returns:\n            num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.\n            num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA\n                rank (with the same GPU index); `None` for intranode settings.\n            num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.\n            is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token is sent to a rank.\n            event: the event after executing the kernel (valid only if `async_finish` is set).\n        \"\"\"\n        num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \\\n            self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),\n                                             async_finish, allocate_on_comm_stream)\n        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)\n\n    # noinspection PyTypeChecker\n    def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n                 handle: Optional[Tuple] = None,\n                 num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,\n                 is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,\n                 topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,\n                 expert_alignment: int = 1, num_worst_tokens: int = 0,\n                 config: Optional[Config] = None,\n                 previous_event: Optional[EventOverlap] = None, async_finish: bool = False,\n                 allocate_on_comm_stream: bool = False) -> \\\n            Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], 
Optional[torch.Tensor],\n                  Optional[torch.Tensor], List[int], Tuple, EventOverlap]:\n        \"\"\"\n        Dispatch tokens to different ranks; both intranode and internode settings are supported.\n        Intranode kernels require all ranks to be visible via NVLink.\n        Internode kernels require the ranks within a node to be visible via NVLink, while ranks with the same GPU\n            index must be visible via RDMA.\n\n        Arguments:\n            x: `torch.Tensor` or a tuple of `torch.Tensor`. For a single tensor, the shape must be `[num_tokens, hidden]`\n                and the type must be `torch.bfloat16`; for a tuple, the first element must be shaped as\n                `[num_tokens, hidden]` with type `torch.float8_e4m3fn`, and the second must be `[num_tokens, hidden // 128]`\n                (`hidden` must be divisible by 128) with type `torch.float`.\n            handle: an optional communication handle; if set, the CPU reuses its layout information to save time.\n            num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.\n            num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA\n                rank (with the same GPU index); `None` for intranode settings.\n            is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token is sent to a rank.\n            num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.\n            topk_idx: `[num_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert indices\n                selected by each token; `-1` means no selection.\n            topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.\n            expert_alignment: align the number of tokens received by each local expert to a multiple of this value.\n    
        num_worst_tokens: the worst-case number of tokens to receive. If specified, there will be no CPU sync, and the\n                call will be CUDA-graph compatible. Note that this flag is for intranode only.\n            config: the performance tuning config.\n            previous_event: the event to wait for before executing the kernel.\n            async_finish: if set, the current stream will not wait for the communication kernels to finish.\n            allocate_on_comm_stream: whether all the allocated tensors are owned by the communication stream.\n\n        Returns:\n            recv_x: received tokens, of the same type (tensor or tuple) as the input `x`, but with the number of tokens\n                equal to the received token count.\n            recv_topk_idx: received expert indices.\n            recv_topk_weights: received expert weights.\n            num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count of\n                each local expert, aligned to the input `expert_alignment`. 
If `num_worst_tokens` is specified, the list\n                will be empty.\n            handle: the returned communication handle.\n            event: the event after executing the kernel (valid only if `async_finish` is set).\n        \"\"\"\n        # Default config\n        config = self.get_dispatch_config(self.group_size) if config is None else config\n\n        # Internode\n        if self.runtime.get_num_rdma_ranks() > 1:\n            return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank,\n                                           num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, num_worst_tokens, config,\n                                           previous_event, async_finish, allocate_on_comm_stream)\n\n        # Launch the kernel with cached or non-cached mode\n        x, x_scales = x if isinstance(x, tuple) else (x, None)\n        if handle is not None:\n            assert topk_idx is None and topk_weights is None\n            rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle\n            num_recv_tokens = recv_src_idx.size(0)\n            recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(\n                x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,\n                expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)\n            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)\n        else:\n            assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None\n            recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, 
recv_channel_prefix_matrix, recv_src_idx, send_head, event = \\\n                self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,\n                                                num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,\n                                                expert_alignment, num_worst_tokens, config,\n                                                getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)\n            handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)\n            return (\n                recv_x, recv_x_scales\n            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(\n                event)\n\n    # noinspection PyTypeChecker\n    def combine(self, x: torch.Tensor, handle: Tuple,\n                topk_weights: Optional[torch.Tensor] = None,\n                bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,\n                config: Optional[Config] = None,\n                previous_event: Optional[EventOverlap] = None, async_finish: bool = False,\n                allocate_on_comm_stream: bool = False) -> \\\n            Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:\n        \"\"\"\n        Combine (reduce) tokens (addition **without** weights) from different ranks; both intranode and internode\n            settings are supported.\n        Intranode kernels require all ranks to be visible via NVLink.\n        Internode kernels require the ranks within a node to be visible via NVLink, while ranks with the same GPU\n            index must be visible via RDMA.\n\n        Arguments:\n            x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send back to their original ranks for reduction.\n            handle: a required communication handle; you can obtain 
this from the dispatch function.\n            topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights to reduce back to their original ranks.\n            bias: zero, one or two `[num_tokens, hidden]` tensors with `torch.bfloat16`, added as the final bias to the output.\n            config: the performance tuning config.\n            previous_event: the event to wait for before executing the kernel.\n            async_finish: if set, the current stream will not wait for the communication kernels to finish.\n            allocate_on_comm_stream: whether all the allocated tensors are owned by the communication stream.\n\n        Returns:\n            recv_x: the reduced tokens from their dispatched ranks.\n            recv_topk_weights: the reduced top-k weights from their dispatched ranks.\n            event: the event after executing the kernel (valid only if `async_finish` is set).\n        \"\"\"\n        # Default config\n        config = self.get_combine_config(self.group_size) if config is None else config\n\n        # Internode\n        if self.runtime.get_num_rdma_ranks() > 1:\n            return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream)\n\n        # NOTES: the second element (`_`) is the sending-side channel prefix matrix, so we use the third one (receiving side)\n        rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle\n        bias_0, bias_1 = Buffer._unpack_bias(bias)\n\n        # Launch the kernel\n        recv_x, recv_topk_weights, event = self.runtime.intranode_combine(x, topk_weights, bias_0, bias_1, src_idx, rank_prefix_matrix,\n                                                                          channel_prefix_matrix, send_head, config,\n                                                                          getattr(previous_event, 'event',\n                                                                                  None), async_finish, 
allocate_on_comm_stream)\n        return recv_x, recv_topk_weights, EventOverlap(event)\n\n    # noinspection PyTypeChecker\n    def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n                           handle: Optional[Tuple] = None,\n                           num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,\n                           is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,\n                           topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,\n                           num_worst_tokens: int = 0, config: Optional[Config] = None,\n                           previous_event: Optional[EventOverlap] = None, async_finish: bool = False,\n                           allocate_on_comm_stream: bool = False) -> \\\n            Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],\n            Optional[torch.Tensor], List[int], Tuple, EventOverlap]:\n        \"\"\"\n        Internode dispatch implementation, for more details, please refer to the `dispatch` docs.\n        Normally, you should not directly call this function.\n        \"\"\"\n        assert config is not None\n\n        # Launch the kernel with cached or non-cached mode\n        x, x_scales = x if isinstance(x, tuple) else (x, None)\n        if handle is not None:\n            assert topk_idx is None and topk_weights is None\n            is_token_in_rank, \\\n                rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \\\n                recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \\\n                recv_src_meta, send_rdma_head, send_nvl_head = handle\n            num_recv_tokens = recv_src_meta.size(0)\n            num_rdma_recv_tokens = send_nvl_head.size(0)\n  
          recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(\n                x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,\n                rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,\n                expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)\n            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)\n        else:\n            assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None\n            recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \\\n                rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \\\n                recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \\\n                recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \\\n                recv_src_meta, send_rdma_head, send_nvl_head, event = self.runtime.internode_dispatch(\n                x, x_scales, topk_idx, topk_weights,\n                num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,\n                0, 0, None, None, None, None,\n                expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)\n            handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,\n                      recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,\n                      send_nvl_head)\n            return (\n                recv_x, recv_x_scales\n            ) if x_scales is not None else recv_x, recv_topk_idx, 
recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(\n                event)\n\n    # noinspection PyTypeChecker\n    def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],\n                          topk_weights: Optional[torch.Tensor] = None,\n                          bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,\n                          config: Optional[Config] = None,\n                          previous_event: Optional[EventOverlap] = None, async_finish: bool = False,\n                          allocate_on_comm_stream: bool = False) -> \\\n            Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:\n        \"\"\"\n        Internode combine implementation, for more details, please refer to the `combine` docs.\n        Normally, you should not directly call this function.\n        \"\"\"\n        assert config is not None\n\n        # Unpack handle and bias\n        is_combined_token_in_rank, \\\n            _, _, \\\n            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \\\n            src_meta, send_rdma_head, send_nvl_head = handle\n        bias_0, bias_1 = Buffer._unpack_bias(bias)\n\n        # Launch the kernel\n        combined_x, combined_topk_weights, event = self.runtime.internode_combine(x, topk_weights, bias_0, bias_1, src_meta,\n                                                                                  is_combined_token_in_rank, rdma_channel_prefix_matrix,\n                                                                                  rdma_rank_prefix_sum, gbl_channel_prefix_matrix,\n                                                                                  send_rdma_head, send_nvl_head, config,\n                                                                                  getattr(previous_event, 'event',\n                                                                                          
None), async_finish, allocate_on_comm_stream)\n        return combined_x, combined_topk_weights, EventOverlap(event)\n\n    def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:\n        \"\"\"\n        Low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer\n            whenever it may have become dirty.\n        For example, after running the normal dispatch/combine, you must call this function before executing any\n            low-latency kernel.\n\n        Arguments:\n            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch; all ranks must hold the same value.\n            hidden: the hidden dimension of each token.\n            num_experts: the total number of experts.\n        \"\"\"\n        self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)\n\n    # noinspection PyTypeChecker\n    def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,\n                             num_max_dispatch_tokens_per_rank: int, num_experts: int,\n                             cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,\n                             dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,\n                             use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,\n                             async_finish: bool = False, return_recv_hook: bool = False) -> \\\n            Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:\n        \"\"\"\n        A low-latency implementation for dispatching with IBGDA.\n        This kernel requires all ranks (whether intranode or internode) to be visible via RDMA\n            (specifically, IBGDA must be enabled).\n        Warning: as there are only two buffers and the returned tensors reuse them, you cannot hold more than 2\n   
         low-latency kernels' result tensors at a single moment.\n\n        Arguments:\n            x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`; only several hidden sizes are\n                supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.\n            topk_idx: `torch.Tensor` with `deep_ep.topk_idx_t` (typically `torch.int64`), shaped as `[num_tokens, num_topk]`;\n                only several top-k sizes are supported. `-1` indices (not selecting any expert) are supported.\n            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch; all ranks must use the same value.\n            num_experts: the total number of experts.\n            cumulative_local_expert_recv_stats: a cumulative per-expert token count tensor for statistics, which should have shape\n                `[num_local_experts]` and be typed as `torch.int`. This is useful for monitoring EP load balance in\n                online services.\n            dispatch_wait_recv_cost_stats: a cumulative statistics tensor of the time spent waiting to receive tokens,\n                which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.\n                This is useful for detecting and precisely localizing slow anomalies.\n            use_fp8: whether to enable FP8 casting; if enabled, the received data will be a tuple of an FP8 tensor and its scaling factors.\n            round_scale: whether to round the scaling factors to powers of 2.\n            use_ue8m0: whether to use UE8M0 as the scaling factor format (available only with `round_scale=True`).\n            async_finish: if set, the current stream will not wait for the communication kernels to finish.\n            return_recv_hook: return a receiving hook if set. If set, the kernel only issues the RDMA requests,\n                **without actually receiving the data**. 
You must call the returned hook to ensure that the data has arrived.\n                If this flag is not set, the kernel itself ensures the data's arrival.\n\n        Returns:\n            recv_x: a tensor or tuple with received tokens for each expert.\n                With `use_fp8=True`: the first element is a `torch.Tensor` shaped as\n                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.\n                The second element holds the corresponding scales for the first element, with shape\n                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` and type `torch.float`,\n                if `use_ue8m0=False`. With `use_ue8m0=True`, the second element is packed and shaped as\n                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.\n                Note that the last two dimensions of the scaling tensors are column-major for TMA compatibility.\n                With `use_fp8=False`, the result is a tensor shaped as\n                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.\n                Moreover, not all tokens are valid: only some of the `num_max_dispatch_tokens_per_rank * num_ranks` slots are,\n                as we do not synchronize the received counts from GPU to CPU (doing so would also be incompatible with CUDA graph).\n            recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each\n                expert receives. 
As mentioned before, not all tokens are valid in `recv_x`.\n            handle: the communication handle to be used in the `low_latency_combine` function.\n            event: the event after executing the kernel (valid only if `async_finish` is set).\n            hook: the receiving hook function (valid only if `return_recv_hook` is set).\n        \"\"\"\n        assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2\n        packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \\\n            self.runtime.low_latency_dispatch(x, topk_idx,\n                                              cumulative_local_expert_recv_stats,\n                                              dispatch_wait_recv_cost_stats,\n                                              num_max_dispatch_tokens_per_rank, num_experts,\n                                              use_fp8, round_scale, use_ue8m0,\n                                              async_finish, return_recv_hook)\n        handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)\n        tensors_to_record = (x, topk_idx, packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info,\n                             packed_recv_layout_range, cumulative_local_expert_recv_stats)\n        return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \\\n            EventOverlap(event, tensors_to_record if async_finish else None), hook\n\n    # noinspection PyTypeChecker\n    def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,\n                            handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,\n                            return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,\n                            
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \\\n            Tuple[torch.Tensor, EventOverlap, Callable]:\n        \"\"\"\n        A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.\n        This kernel requires all ranks (whether intranode or internode) to be visible via RDMA\n            (specifically, IBGDA must be enabled).\n        Warning: as there are only two buffers and the returned tensors reuse them, you cannot hold more than 2\n            low-latency kernels' result tensors at a single moment.\n\n        Arguments:\n            x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,\n                the locally computed tokens to be sent back to their original ranks and reduced.\n            topk_idx: `[num_combined_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert\n                indices selected by the dispatched tokens. `-1` indices (not selecting any expert) are supported. Note that\n                `num_combined_tokens` equals the number of dispatched tokens.\n            topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched\n                tokens. The received tokens will be reduced with the weights in this tensor.\n            handle: the communication handle returned by the `low_latency_dispatch` function.\n            use_logfmt: whether to use an internal \"LogFMT with dynamic per-64-channel cast\" format (10 bits).\n            zero_copy: whether the tensor has already been copied into the RDMA buffer; this should be used together\n                with `get_next_low_latency_combine_buffer`.\n            async_finish: if set, the current stream will not wait for the communication kernels to finish.\n            return_recv_hook: return a receiving hook if set. 
If set, the kernel only issues the RDMA requests,\n                **without actually receiving the data**. You must call the returned hook to ensure that the data has arrived.\n                If this flag is not set, the kernel itself ensures the data's arrival.\n            out: an optional in-place output tensor; if set, the kernel writes the result to this tensor and returns it directly.\n            combine_wait_recv_cost_stats: a cumulative statistics tensor of the time spent waiting to receive tokens,\n                which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.\n                This is useful for detecting and precisely localizing slow anomalies.\n\n        Returns:\n            combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.\n            event: the event after executing the kernel (valid only if `async_finish` is set).\n            hook: the receiving hook function (valid only if `return_recv_hook` is set).\n        \"\"\"\n        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle\n        assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2\n        combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,\n                                                                   combine_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank,\n                                                                   num_experts, use_logfmt, zero_copy, async_finish, return_recv_hook, out)\n        tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)\n        return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook\n\n    def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False):\n        \"\"\"\n        Mask (or unmask) a rank during communication (dispatch, combine, and 
clean).\n\n        Arguments:\n            rank_to_mask: the rank to mask (or unmask).\n            mask: if True, mask the rank (do not receive from or send to it); otherwise, unmask it.\n\n        \"\"\"\n        self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask)\n\n    def low_latency_query_mask_buffer(self, mask_status: torch.Tensor):\n        \"\"\"\n        Query the mask status of all ranks.\n\n        Arguments:\n            mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means masked and `0` means unmasked.\n\n        \"\"\"\n        self.runtime.low_latency_query_mask_buffer(mask_status)\n\n    def low_latency_clean_mask_buffer(self):\n        \"\"\"\n        Clean the mask buffer.\n\n        \"\"\"\n        self.runtime.low_latency_clean_mask_buffer()\n\n    def get_next_low_latency_combine_buffer(self, handle: object):\n        \"\"\"\n        Get the raw registered RDMA buffer tensor for the next low-latency combine, so that the next combine kernel can skip the copy.\n\n        Arguments:\n            handle: the communication handle returned by the `low_latency_dispatch` function.\n\n        Returns:\n            buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape\n                `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`, which you should fill\n                yourself.\n        \"\"\"\n        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle\n        return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)\n"
  },
  {
    "path": "deep_ep/utils.py",
    "content": "import os\nimport torch\nimport torch.distributed as dist\nfrom typing import Any, Optional, Tuple\n\n# noinspection PyUnresolvedReferences\nfrom deep_ep_cpp import EventHandle\n\n\nclass EventOverlap:\n    \"\"\"\n    A wrapper class to manage CUDA events, which also makes overlapping more convenient.\n\n    Attributes:\n        event: the captured CUDA event.\n        extra_tensors: an easier way to simulate PyTorch's `Tensor.record_stream`, which may be useful with CUDA graph.\n    \"\"\"\n\n    def __init__(self, event: Optional[EventHandle] = None, extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None:\n        \"\"\"\n        Initialize the class.\n\n        Arguments:\n            event: the captured CUDA event.\n            extra_tensors: an easier way to simulate PyTorch's `Tensor.record_stream`, which may be useful with CUDA graph.\n        \"\"\"\n        self.event = event\n\n        # NOTES: we use extra tensors to achieve stream recording; otherwise,\n        # stream recording would be incompatible with CUDA graph.\n        self.extra_tensors = extra_tensors\n\n    def current_stream_wait(self) -> None:\n        \"\"\"\n        Make the current stream `torch.cuda.current_stream()` wait for the event to finish.\n        \"\"\"\n        assert self.event is not None\n        self.event.current_stream_wait()\n\n    def __enter__(self) -> Any:\n        \"\"\"\n        Utility for overlapping and Python `with` syntax.\n\n        You can overlap the kernels on the current stream with the following example:\n        ```python\n        event_overlap = event_after_all_to_all_kernels()\n        with event_overlap:\n            do_something_on_current_stream()\n        # After exiting the `with` scope, the current stream will wait for the event to finish.\n        ```\n        \"\"\"\n        return self\n\n    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:\n        \"\"\"\n        Utility for overlapping and Python `with` 
syntax.\n\n        Please follow the example in the `__enter__` function.\n        \"\"\"\n        if self.event is not None:\n            self.event.current_stream_wait()\n\n\ndef check_nvlink_connections(group: dist.ProcessGroup):\n    \"\"\"\n    Check NVLink connection between every pair of GPUs.\n\n    Arguments:\n        group: the communication group.\n    \"\"\"\n    # Check NVLink connection\n    # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2\n    # TODO: check all cases, all local-node GPUs in the group should be connected via NVLink\n    if 'PCIE' in torch.cuda.get_device_name():\n        assert group.size() <= 2, 'PCIe GPUs only have pairwise NVLink connections'\n\n        # noinspection PyUnresolvedReferences\n        import pynvml\n        pynvml.nvmlInit()\n\n        # noinspection PyTypeChecker\n        devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',')\n        physical_device_idx = int(devices[torch.cuda.current_device()])\n        physical_device_indices = [\n            0,\n        ] * group.size()\n        dist.all_gather_object(physical_device_indices, physical_device_idx, group)\n\n        # Check whether they are all connected via NVLink\n        # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438\n        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices]\n        for i, handle in enumerate(handles):\n            for j, peer_handle in enumerate(handles):\n                if i >= j:\n                    continue\n                status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)\n                assert status == pynvml.NVML_P2P_STATUS_OK,\\\n                    f'GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink'\n\n        # Close NVML\n        
pynvml.nvmlShutdown()\n"
  },
  {
    "path": "format.sh",
    "content": "#!/usr/bin/env bash\n# Usage:\n#    # Do work and commit your work.\n\n#    # Format files that differ from origin/main.\n#    bash format.sh\n\n#    # Commit changed files with message 'Run yapf and ruff'\n#\n#\n# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.\n# You are encouraged to run this locally before pushing changes for review.\n\n# Cause the script to exit if a single command fails\nset -eo pipefail\n\n# If yapf/ruff is not installed, install according to the requirements\nif ! (yapf --version &>/dev/null && ruff --version &>/dev/null); then\n    pip install -r requirements-lint.txt\nfi\n\nYAPF_VERSION=$(yapf --version | awk '{print $2}')\nRUFF_VERSION=$(ruff --version | awk '{print $2}')\n\necho 'yapf: Check Start'\n\nYAPF_FLAGS=(\n    '--recursive'\n    '--parallel'\n)\n\nYAPF_EXCLUDES=(\n    '--exclude' 'build/**'\n)\n\n# Format specified files\nformat() {\n    yapf --in-place \"${YAPF_FLAGS[@]}\" \"$@\"\n}\n\n# Format all files\nformat_all() {\n    yapf --in-place \"${YAPF_FLAGS[@]}\" \"${YAPF_EXCLUDES[@]}\" .\n}\n\n# Format files that differ from main branch\nformat_changed() {\n    # The `if` guard ensures that the list of filenames is not empty, which\n    # could cause ruff to receive 0 positional arguments, making it hang\n    # waiting for STDIN.\n    #\n    # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that\n    # exist on both branches.\n    if git show-ref --verify --quiet refs/remotes/origin/main; then\n        BASE_BRANCH=\"origin/main\"\n    else\n        BASE_BRANCH=\"main\"\n    fi\n\n    MERGEBASE=\"$(git merge-base $BASE_BRANCH HEAD)\"\n\n    if ! 
git diff --diff-filter=ACM --quiet --exit-code \"$MERGEBASE\" -- '*.py' '*.pyi' &>/dev/null; then\n        git diff --name-only --diff-filter=ACM \"$MERGEBASE\" -- '*.py' '*.pyi' | xargs -P 5 \\\n             yapf --in-place \"${YAPF_EXCLUDES[@]}\" \"${YAPF_FLAGS[@]}\"\n    fi\n}\n\n# If `--all` is passed, then any further arguments are ignored and the\n# entire python directory is formatted.\nif [[ \"$1\" == '--all' ]]; then\n   format_all\nelse\n   # Format only the files that changed in last commit.\n   format_changed\nfi\necho 'yapf: Done'\n\necho 'ruff: Check Start'\n# Lint specified files\nlint() {\n    ruff check \"$@\"\n}\n\n# Lint files that differ from main branch. Ignores dirs that are not slated\n# for autolint yet.\nlint_changed() {\n    # The `if` guard ensures that the list of filenames is not empty, which\n    # could cause ruff to receive 0 positional arguments, making it hang\n    # waiting for STDIN.\n    #\n    # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that\n    # exist on both branches.\n    if git show-ref --verify --quiet refs/remotes/origin/main; then\n        BASE_BRANCH=\"origin/main\"\n    else\n        BASE_BRANCH=\"main\"\n    fi\n\n    MERGEBASE=\"$(git merge-base $BASE_BRANCH HEAD)\"\n\n    if ! git diff --diff-filter=ACM --quiet --exit-code \"$MERGEBASE\" -- '*.py' '*.pyi' &>/dev/null; then\n        git diff --name-only --diff-filter=ACM \"$MERGEBASE\" -- '*.py' '*.pyi' | xargs \\\n             ruff check\n    fi\n}\n\n# Run Ruff\n# If `--all` is passed, then any further arguments are ignored and the\n# entire python directory is linted.\nif [[ \"$1\" == '--all' ]]; then\n   lint . 
\nelse\n   # Check only the files that changed in last commit.\n   lint_changed\nfi\n\necho 'ruff: Done'\n\n# # params: tool name, tool version, required version\ntool_version_check() {\n    if [[ $2 != $3 ]]; then\n        echo \"Wrong $1 version installed: $3 is required, not $2.\"\n        pip install -r requirements-lint.txt\n    fi\n}\n\necho 'clang-format: Check Start'\n# If clang-format is available, run it; otherwise, skip\nif command -v clang-format &>/dev/null; then\n    CLANG_FORMAT_VERSION=$(clang-format --version | awk '{print $3}')\n    tool_version_check \"clang-format\" \"$CLANG_FORMAT_VERSION\" \"$(grep clang-format requirements-lint.txt | cut -d'=' -f3)\"\n\n    CLANG_FORMAT_FLAGS=(\"-i\")\n\n    # Format all C/C++ files in the repo, excluding specified directories\n    clang_format_all() {\n        # Replace \"#pragma unroll\" by \"// #pragma unroll\"\n        find . -type f \\( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hpp' -o -name '*.cu' -o -name '*.cuh' \\) \\\n            -not -path \"./build/*\" \\\n            -exec perl -pi -e 's/#pragma unroll/\\/\\/#pragma unroll/g' {} +\n        find . -type f \\( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hpp' -o -name '*.cu' -o -name '*.cuh' \\) \\\n            -not -path \"./build/*\" \\\n            -exec clang-format -i {} +\n        # Replace \"// #pragma unroll\" by \"#pragma unroll\"\n        find . 
-type f \\( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hpp' -o -name '*.cu' -o -name '*.cuh' \\) \\\n            -not -path \"./build/*\" \\\n            -exec perl -pi -e 's/\\/\\/ *#pragma unroll/#pragma unroll/g' {} +\n    }\n\n    # Format changed C/C++ files relative to main\n    clang_format_changed() {\n        if git show-ref --verify --quiet refs/remotes/origin/main; then\n            BASE_BRANCH=\"origin/main\"\n        else\n            BASE_BRANCH=\"main\"\n        fi\n\n        MERGEBASE=\"$(git merge-base $BASE_BRANCH HEAD)\"\n\n        if ! git diff --diff-filter=ACM --quiet --exit-code \"$MERGEBASE\" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' '*.cu' '*.cuh' &>/dev/null; then\n            git diff --name-only --diff-filter=ACM \"$MERGEBASE\" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' '*.cu' '*.cuh' | xargs perl -pi -e 's/#pragma unroll/\\/\\/#pragma unroll/g'\n            git diff --name-only --diff-filter=ACM \"$MERGEBASE\" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' '*.cu' '*.cuh' | xargs clang-format -i\n            git diff --name-only --diff-filter=ACM \"$MERGEBASE\" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' '*.cu' '*.cuh' | xargs perl -pi -e 's/\\/\\/ *#pragma unroll/#pragma unroll/g'\n        fi\n    }\n\n    if [[ \"$1\" == '--all' ]]; then\n       # If --all is given, format all eligible C/C++ files\n       clang_format_all\n    else\n       # Otherwise, format only changed C/C++ files\n       clang_format_changed\n    fi\nelse\n    echo \"clang-format not found. Skipping C/C++ formatting.\"\nfi\necho 'clang-format: Done'\n\n# Check if there are any uncommitted changes after all formatting steps.\n# If there are, ask the user to review and stage them.\nif ! git diff --quiet &>/dev/null; then\n    echo 'Reformatted files. 
Please review and stage the changes.'\n    echo 'Changes not staged for commit:'\n    echo\n    git --no-pager diff --name-only\n\n    echo 'You can also copy-paste the diff below to fix the lint:'\n    echo\n    git --no-pager diff\n\n    exit 1\nfi\n\necho 'All checks passed'"
  },
  {
    "path": "install.sh",
    "content": "# Change the current directory to the project root\noriginal_dir=$(pwd)\nscript_dir=$(dirname \"$0\")\ncd \"$script_dir\"\n\n# Remove the old dist directory, then build and install the wheel\nrm -rf dist\npython setup.py bdist_wheel\npip install dist/*.whl\n\n# Return to the user's original directory\ncd \"$original_dir\"\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.yapf]\nbased_on_style = \"pep8\"\ncolumn_limit = 140\nindent_width = 4\n\n[tool.ruff.lint]\nselect = [\n    # pycodestyle\n    \"E\", \"W\",\n    # Pyflakes\n    \"F\",\n    # pyupgrade\n    # \"UP\",\n    # flake8-bugbear\n    \"B\",\n    # flake8-simplify\n    \"SIM\",\n    # isort\n    # \"I\",\n]\nignore = [\n    # Module level import not at top of file\n    \"E402\",\n    # star imports\n    \"F405\", \"F403\",\n    # ambiguous name\n    \"E741\",\n    # line too long\n    \"E501\",\n    # key in dict.keys()\n    \"SIM118\",\n    # memory leaks\n    \"B019\",\n    # No such file or directory\n    \"E902\",\n]\nexclude = [\n    \"deep_ep/__init__.py\"\n]"
  },
  {
    "path": "requirements-lint.txt",
    "content": "clang-format==15.0.7\nyapf==0.40.2\nruff==0.6.5"
  },
  {
    "path": "setup.py",
    "content": "import os\nimport subprocess\nimport setuptools\nimport importlib.util\n\nfrom pathlib import Path\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\n# Wheel specific: the wheels only include the soname of the host library `libnvshmem_host.so.X`\ndef get_nvshmem_host_lib_name(base_dir):\n    path = Path(base_dir).joinpath('lib')\n    for file in path.rglob('libnvshmem_host.so.*'):\n        return file.name\n    raise ModuleNotFoundError('libnvshmem_host.so not found')\n\n\nif __name__ == '__main__':\n    disable_nvshmem = False\n    nvshmem_dir = os.getenv('NVSHMEM_DIR', None)\n    nvshmem_host_lib = 'libnvshmem_host.so'\n    if nvshmem_dir is None:\n        try:\n            nvshmem_dir = importlib.util.find_spec(\"nvidia.nvshmem\").submodule_search_locations[0]\n            nvshmem_host_lib = get_nvshmem_host_lib_name(nvshmem_dir)\n            import nvidia.nvshmem as nvshmem  # noqa: F401\n        except (ModuleNotFoundError, AttributeError, IndexError):\n            print(\n                'Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. 
All internode and low-latency features are disabled.\\n'\n            )\n            disable_nvshmem = True\n    else:\n        disable_nvshmem = False\n\n    if not disable_nvshmem:\n        assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}'\n\n    cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']\n    nvcc_flags = ['-O3', '-Xcompiler', '-O3']\n    sources = ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu']\n    include_dirs = ['csrc/']\n    library_dirs = []\n    nvcc_dlink = []\n    extra_link_args = ['-lcuda']\n\n    # NVSHMEM flags\n    if disable_nvshmem:\n        cxx_flags.append('-DDISABLE_NVSHMEM')\n        nvcc_flags.append('-DDISABLE_NVSHMEM')\n    else:\n        sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu'])\n        include_dirs.extend([f'{nvshmem_dir}/include'])\n        library_dirs.extend([f'{nvshmem_dir}/lib'])\n        nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device'])\n        extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib'])\n\n    if int(os.getenv('DISABLE_SM90_FEATURES', 0)):\n        # Prefer A100\n        os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '8.0')\n\n        # Disable some SM90 features: FP8, launch methods, and TMA\n        cxx_flags.append('-DDISABLE_SM90_FEATURES')\n        nvcc_flags.append('-DDISABLE_SM90_FEATURES')\n\n        # Internode and low-latency kernels must be disabled\n        assert disable_nvshmem\n    else:\n        # Prefer H800 series\n        os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0')\n\n        # CUDA 12 flags\n        nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10'])\n\n    # Disable LD/ST tricks, as some CUDA versions do not support 
`.L1::no_allocate`\n    if os.environ['TORCH_CUDA_ARCH_LIST'].strip() != '9.0':\n        assert int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1\n        os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'\n\n    # Disable aggressive PTX instructions\n    if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '1')):\n        cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')\n        nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')\n\n    # Bits of `topk_idx.dtype`, choices are 32 and 64\n    if \"TOPK_IDX_BITS\" in os.environ:\n        topk_idx_bits = int(os.environ['TOPK_IDX_BITS'])\n        cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')\n        nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')\n\n    # Put them together\n    extra_compile_args = {\n        'cxx': cxx_flags,\n        'nvcc': nvcc_flags,\n    }\n    if len(nvcc_dlink) > 0:\n        extra_compile_args['nvcc_dlink'] = nvcc_dlink\n\n    # Summary\n    print('Build summary:')\n    print(f' > Sources: {sources}')\n    print(f' > Includes: {include_dirs}')\n    print(f' > Libraries: {library_dirs}')\n    print(f' > Compilation flags: {extra_compile_args}')\n    print(f' > Link flags: {extra_link_args}')\n    print(f' > Arch list: {os.environ[\"TORCH_CUDA_ARCH_LIST\"]}')\n    print(f' > NVSHMEM path: {nvshmem_dir}')\n    print()\n\n    # noinspection PyBroadException\n    try:\n        cmd = ['git', 'rev-parse', '--short', 'HEAD']\n        revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()\n    except Exception as _:\n        revision = ''\n\n    setuptools.setup(name='deep_ep',\n                     version='1.2.1' + revision,\n                     packages=setuptools.find_packages(include=['deep_ep']),\n                     ext_modules=[\n                         CUDAExtension(name='deep_ep_cpp',\n                                       include_dirs=include_dirs,\n                                       library_dirs=library_dirs,\n                                       
sources=sources,\n                                       extra_compile_args=extra_compile_args,\n                                       extra_link_args=extra_link_args)\n                     ],\n                     cmdclass={'build_ext': BuildExtension})\n"
  },
  {
    "path": "tests/test_internode.py",
    "content": "import argparse\nimport os\nimport time\nimport torch\nimport torch.distributed as dist\n\n# noinspection PyUnresolvedReferences\nimport deep_ep\nfrom utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back, hash_tensor\n\n# Test compatibility with low latency functions\nimport test_low_latency\n\n\n# noinspection PyShadowingNames\ndef test_main(args: argparse.Namespace,\n              num_sms: int,\n              local_rank: int,\n              num_local_ranks: int,\n              num_ranks: int,\n              num_nodes: int,\n              rank: int,\n              buffer: deep_ep.Buffer,\n              group: dist.ProcessGroup,\n              skip_benchmark: bool = False):\n    # Settings\n    num_tokens, hidden = args.num_tokens, args.hidden\n    num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts\n\n    assert num_experts % num_ranks == 0 and num_local_ranks == 8\n    if local_rank == 0:\n        print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)\n\n    # Random data\n    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank\n    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')\n    x_e4m3 = per_token_cast_to_fp8(x)\n    x_pure_rand_e4m3 = per_token_cast_to_fp8(x_pure_rand)\n    x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T)\n    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1\n    group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)\n    group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices\n    masked_scores = create_grouped_scores(scores, group_idx, num_nodes)\n    topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]\n    topk_idx = 
topk_idx.to(deep_ep.topk_idx_t)\n    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank\n    topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')\n    rank_idx = topk_idx // (num_experts // num_ranks)\n    rank_idx = rank_idx.to(torch.int64)\n    rank_idx.masked_fill_(topk_idx == -1, -1)\n    inplace_unique(rank_idx, num_ranks)\n    rdma_rank_idx = rank_idx // num_local_ranks\n    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)\n    inplace_unique(rdma_rank_idx, num_nodes)\n    hash_value = 0\n\n    # RDMA dispatch counts\n    rdma_idx = topk_idx // (num_experts // num_nodes)\n    rdma_idx.masked_fill_(topk_idx == -1, -1)\n    inplace_unique(rdma_idx, num_nodes)\n    num_rdma_token_sent = rdma_idx.ne(-1).sum().item()\n\n    # Expert meta\n    num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')\n    for i in range(num_experts):\n        num_tokens_per_expert[i] = (topk_idx == i).sum()\n    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()\n    dist.all_reduce(gbl_num_tokens_per_expert, group=group)\n\n    # Rank layout meta\n    num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')\n    num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda')\n    token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')\n    for i in range(num_ranks):\n        num_tokens_per_rank[i] = (rank_idx == i).sum()\n        token_sel = (rank_idx == i).max(dim=-1)[0]\n        count = token_sel.sum().item()\n        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]\n        tokens[:count] = torch.sort(tokens[:count])[0]\n        token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')\n    for i in range(num_nodes):\n        num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()\n    token_idx_in_rank = 
token_idx_in_rank.T.contiguous().to(torch.int)\n    is_token_in_rank = token_idx_in_rank >= 0\n    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()\n    dist.all_reduce(gbl_num_tokens_per_rank, group=group)\n\n    ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \\\n        buffer.get_dispatch_layout(topk_idx, num_experts)\n    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)\n    assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)\n    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)\n    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)\n    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]\n    if local_rank == 0:\n        print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)\n        print('', flush=True)\n    group.barrier()\n    time.sleep(1)\n\n    # Config\n    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (24, 48, 96, 144, 160) else 512)\n    config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)\n\n    # Test dispatch\n    # noinspection PyShadowingNames\n    def check_data(check_x, recv_gbl_rank_prefix_sum):\n        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))\n        check_start = 0\n        for i in range(num_ranks):\n            check_end = recv_gbl_rank_prefix_sum[i].item()\n            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0\n            check_start = check_end\n\n    for previous_mode in (False, True):\n        for async_mode in (False, True):\n            for current_x in (x_pure_rand, x, x_pure_rand_e4m3, x_e4m3):\n                for with_topk in (False, True):\n                    is_rand = current_x is x_pure_rand or current_x is x_pure_rand_e4m3\n                    if local_rank == 0:\n                        print(\n                            f'[testing] Running with {\"FP8\" 
if isinstance(current_x, tuple) else \"BF16\"}, {\"with\" if with_topk else \"without\"} top-k (async={async_mode}, previous={previous_mode}) ...',\n                            flush=True,\n                            end='')\n                    dispatch_args = {\n                        'x': current_x,\n                        'num_tokens_per_rank': num_tokens_per_rank,\n                        'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,\n                        'is_token_in_rank': is_token_in_rank,\n                        'num_tokens_per_expert': num_tokens_per_expert,\n                        'config': config,\n                        'async_finish': async_mode\n                    }\n                    if with_topk:\n                        dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if is_rand else topk_weights})\n                    if previous_mode:\n                        dispatch_args.update({'previous_event': buffer.capture()})\n                    recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(\n                        **dispatch_args)\n                    event.current_stream_wait() if async_mode else ()\n\n                    if current_x is x_pure_rand or current_x is x:\n                        hash_value += hash_tensor(recv_x)\n                    else:\n                        hash_value += hash_tensor(recv_x[0])\n                        hash_value += hash_tensor(recv_x[1])\n\n                    recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x\n\n                    # Checks\n                    recv_gbl_rank_prefix_sum = handle[-4]\n                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), \\\n                        f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'\n                    assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == 
recv_num_tokens_per_expert_list\n                    if not is_rand:\n                        check_data(recv_x, recv_gbl_rank_prefix_sum)\n                    recv_topk_weights_clone = None\n                    if with_topk:\n                        # Check `topk_idx`\n                        assert (recv_topk_idx.eq(-1) |\n                                ((recv_topk_idx >= 0) &\n                                 (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()\n                        for i, count in enumerate(recv_num_tokens_per_expert_list):\n                            assert recv_topk_idx.eq(i).sum().item() == count\n\n                        # Check `topk_weights`\n                        recv_topk_weights_clone = recv_topk_weights.clone()\n                        if not is_rand:\n                            recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(\n                                dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]\n                            check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)\n\n                    # Test `num_worst_tokens != 0`\n                    if with_topk:\n                        num_worst_tokens = num_tokens * num_ranks\n                        dispatch_args.update({'num_worst_tokens': num_worst_tokens})\n                        recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)\n                        event.current_stream_wait() if async_mode else ()\n                        recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x\n                        assert len(empty_list) == 0\n                        assert num_worst_tokens == recv_worst_x.size(0)\n                        assert num_worst_tokens == recv_worst_topk_idx.size(0)\n                        assert num_worst_tokens == recv_worst_topk_weights.size(0)\n     
                   assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])\n                        assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])\n                        assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])\n                        assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()\n\n                    # Test cached dispatch (must be run without top-k inputs)\n                    if not with_topk:\n                        dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}\n                        if previous_mode:\n                            dispatch_args.update({'previous_event': buffer.capture()})\n                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)\n                        event.current_stream_wait() if async_mode else ()\n                        recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x\n                        if not is_rand:\n                            check_data(recv_x, recv_gbl_rank_prefix_sum)\n\n                    # Test combine\n                    bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')\n                    bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')\n                    combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}\n                    if with_topk:\n                        combine_args.update({'topk_weights': recv_topk_weights})\n                    if previous_mode:\n                        combine_args.update({'previous_event': buffer.capture()})\n                    combined_x, combined_topk_weights, event = buffer.combine(**combine_args)\n                    event.current_stream_wait() if async_mode else ()\n                    check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / 
is_token_in_rank.sum(dim=1).unsqueeze(1)\n                    ref_x = x_pure_rand if is_rand else x\n                    assert calc_diff(check_x, ref_x) < (5e-4 if current_x is x_pure_rand_e4m3 else 5e-6)\n                    if with_topk:\n                        check_topk_weights = combined_topk_weights if is_rand else (combined_topk_weights /\n                                                                                    is_token_in_rank.sum(dim=1).unsqueeze(1))\n                        ref_topk_weights = topk_weights_pure_rand if is_rand else topk_weights\n                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9\n\n                    hash_value += hash_tensor(recv_x)\n\n                    # For later tuning\n                    dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2\n                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2\n                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes\n                    combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes\n\n                    if local_rank == 0:\n                        print(' passed', flush=True)\n    if local_rank == 0:\n        print('', flush=True)\n\n    if skip_benchmark:\n        return hash_value\n\n    # Tune dispatch performance\n    best_dispatch_results = None\n    fp8_factor = (1 + 4 / 128) / 2\n    for current_x in (x_e4m3, x):\n        best_time, best_results = 1e10, None\n        rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes\n        nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes\n        for nvl_chunk_size in range(4, 45, 4):\n            for rdma_chunk_size in range(4, 33, 4):\n                config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)\n                tune_args = {'x': 
current_x, 'handle': handle, 'config': config}\n                t, notify_t = bench_kineto(\n                    lambda: buffer.dispatch(**tune_args),  # noqa: B023\n                    ('dispatch', 'notify'),\n                    suppress_kineto_output=True)\n                if t < best_time:\n                    best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t)\n                if local_rank == 0:\n                    print(\n                        f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: '\n                        f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, '\n                        f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ',\n                        flush=True)\n        if local_rank == 0:\n            print(\n                f'[tuning] Best dispatch ({\"FP8\" if isinstance(current_x, tuple) else \"BF16\"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: '\n                f'{best_results[3] * 1e6:.0f} + {best_time * 1e6:.0f} us, '\n                f'{rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',\n                flush=True)\n            print('', flush=True)\n\n        if isinstance(current_x, tuple):\n            # Gather the best FP8 config from rank 0\n            best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')\n            all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]\n            dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)\n            best_dispatch_results = all_best_fp8_results_list[0].tolist()\n    dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2],\n                                 
    rdma_buffer_size)\n\n    dispatch_args = {\n        'x': x,\n        'num_tokens_per_rank': num_tokens_per_rank,\n        'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,\n        'is_token_in_rank': is_token_in_rank,\n        'num_tokens_per_expert': num_tokens_per_expert,\n        'config': dispatch_config if dispatch_config is not None else config\n    }\n    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)\n\n    # Tune combine performance\n    best_time, best_results = 1e10, None\n    for nvl_chunk_size in range(1, 8, 1):\n        for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):\n            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)\n            tune_args = {'x': recv_x, 'handle': handle, 'config': config}\n            t, notify_t = bench_kineto(\n                lambda: buffer.combine(**tune_args),  # noqa: B023\n                ('combine', 'notify'),\n                suppress_kineto_output=True)\n            if local_rank == 0:\n                print(\n                    f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: '\n                    f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, '\n                    f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), '\n                    f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ',\n                    flush=True)\n                if t < best_time:\n                    best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t)\n\n    if local_rank == 0:\n        print(\n            f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, '\n            f'{best_results[3] * 1e6:.2f} + {best_time * 1e6:.2f} us, '\n            f'{combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)',\n            flush=True)\n        
print('', flush=True)\n    return hash_value\n\n\n# noinspection PyUnboundLocalVariable,PyShadowingNames\ndef test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):\n    num_nodes = int(os.getenv('WORLD_SIZE', 1))\n    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)\n    if args.test_ll_compatibility:\n        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9\n\n    num_sms = 24\n    num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)\n\n    buffer = deep_ep.Buffer(group,\n                            int(2e9),\n                            int(1e9),\n                            low_latency_mode=args.test_ll_compatibility,\n                            num_qps_per_rank=num_qps_per_rank,\n                            explicitly_destroy=True)\n    assert num_local_ranks == 8 and num_ranks > 8\n\n    for seed in range(int(1e9)):\n        if local_rank == 0:\n            print(f'Testing with seed {seed} ...', flush=True)\n        torch.manual_seed(rank + seed)\n        ref_hash = 0\n        for i in (num_sms, ):\n            ref_hash += test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group,\n                                  args.pressure_test_mode == 1)\n            if local_rank == 0:\n                print('', flush=True)\n        if args.pressure_test_mode == 0:\n            break\n\n        if local_rank == 0:\n            print(f'{ref_hash=}')\n            print('', flush=True)\n\n        for _ in range(20):\n            torch.manual_seed(rank + seed)\n            current_hash = 0\n            for i in (num_sms, ):\n                current_hash += test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group,\n                                          args.pressure_test_mode == 1)\n                if local_rank == 0:\n                    print('', flush=True)\n            assert current_hash == 
ref_hash\n\n    # Test compatibility with low latency functions\n    if args.test_ll_compatibility:\n        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)\n        test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)\n\n    # Destroy the buffer runtime and communication group\n    buffer.destroy()\n    dist.barrier()\n    dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Test internode EP kernels')\n    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')\n    parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)')\n    parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')\n    parser.add_argument('--num-topk-groups', type=int, default=None, help='Number of top-k groups (default: `min(num_nodes, 4)`)')\n    parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')\n    parser.add_argument(\n        '--pressure-test-mode',\n        type=int,\n        default=0,\n        help='Pressure test mode. 0: don\\'t do pressure test, 1: do pressure test without benchmarks, 2: do pressure test with benchmarks')\n    parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)')\n    parser.add_argument('--test-ll-compatibility', action='store_true', help='Whether to test compatibility with low-latency kernels')\n    args = parser.parse_args()\n\n    # Set default `num_topk_groups` if not provided\n    if args.num_topk_groups is None:\n        num_nodes = int(os.getenv('WORLD_SIZE', 1))\n        args.num_topk_groups = min(num_nodes, 4)\n\n    num_processes = args.num_processes\n    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)\n"
  },
  {
    "path": "tests/test_intranode.py",
    "content": "import argparse\nimport time\nimport torch\nimport torch.distributed as dist\n\n# noinspection PyUnresolvedReferences\nimport deep_ep\nfrom utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back\n\n# Test compatibility with low latency functions\nimport test_low_latency\n\n\n# noinspection PyShadowingNames\ndef test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer,\n              group: dist.ProcessGroup):\n    # Settings\n    num_tokens, hidden = args.num_tokens, args.hidden\n    num_topk, num_experts = args.num_topk, args.num_experts\n\n    assert num_experts % num_ranks == 0\n    if local_rank == 0:\n        print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)\n\n    # Random data\n    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank\n    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')\n    x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None\n    x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None\n    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1\n    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]\n    topk_idx = topk_idx.to(deep_ep.topk_idx_t)\n    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank\n    topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')\n    rank_idx = topk_idx // (num_experts // num_ranks)\n    rank_idx = rank_idx.to(torch.int64)\n    rank_idx.masked_fill_(topk_idx == -1, -1)\n    inplace_unique(rank_idx, num_ranks)\n\n    # Expert meta\n    num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')\n    for i in range(num_experts):\n        num_tokens_per_expert[i] 
= (topk_idx == i).sum()\n    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()\n    dist.all_reduce(gbl_num_tokens_per_expert, group=group)\n\n    # Rank layout meta\n    num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')\n    token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')\n    for i in range(num_ranks):\n        num_tokens_per_rank[i] = (rank_idx == i).sum()\n        token_sel = (rank_idx == i).max(dim=-1)[0]\n        count = token_sel.sum().item()\n        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]\n        tokens[:count] = torch.sort(tokens[:count])[0]\n        token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')\n    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)\n    is_token_in_rank = token_idx_in_rank >= 0\n    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()\n    dist.all_reduce(gbl_num_tokens_per_rank, group=group)\n\n    ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \\\n        buffer.get_dispatch_layout(topk_idx, num_experts)\n    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)\n    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)\n    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)\n    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]\n    if local_rank == 0:\n        print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)\n        print('', flush=True)\n    group.barrier()\n    time.sleep(1)\n\n    # Config\n    nvl_buffer_size = 256\n    config = deep_ep.Config(num_sms, 8, nvl_buffer_size)\n\n    # Test dispatch\n    # noinspection PyShadowingNames\n    def check_data(check_x, rank_prefix_matrix):\n        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))\n        check_start = 0\n        for i in range(num_ranks):\n            check_end = 
rank_prefix_matrix[i][rank].item()\n            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0\n            check_start = check_end\n\n    for previous_mode in (False, True):\n        for async_mode in (False, True):\n            for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)):\n                for with_topk in (False, True):\n                    if local_rank == 0:\n                        print(\n                            f'[testing] Running with {\"FP8\" if isinstance(current_x, tuple) else \"BF16\"}, {\"with\" if with_topk else \"without\"} top-k (async={async_mode}, previous={previous_mode}) ...',\n                            flush=True,\n                            end='')\n                    dispatch_args = {\n                        'x': current_x,\n                        'num_tokens_per_rank': num_tokens_per_rank,\n                        'is_token_in_rank': is_token_in_rank,\n                        'num_tokens_per_expert': num_tokens_per_expert,\n                        'config': config,\n                        'async_finish': async_mode\n                    }\n                    if with_topk:\n                        dispatch_args.update({\n                            'topk_idx': topk_idx,\n                            'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights\n                        })\n                    if previous_mode:\n                        dispatch_args.update({'previous_event': buffer.capture()})\n                    recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(\n                        **dispatch_args)\n                    event.current_stream_wait() if async_mode else ()\n                    recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x\n\n                    # Checks\n                    rank_prefix_matrix = handle[0]\n                    
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(\n                        0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'\n                    assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list\n                    if current_x is not x_pure_rand:\n                        check_data(recv_x, rank_prefix_matrix)\n                    recv_topk_weights_clone = None\n                    if with_topk:\n                        # Check `topk_idx`\n                        assert (recv_topk_idx.eq(-1) |\n                                ((recv_topk_idx >= 0) &\n                                 (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()\n                        for i, count in enumerate(recv_num_tokens_per_expert_list):\n                            assert recv_topk_idx.eq(i).sum().item() == count\n\n                        # Check `topk_weights`\n                        recv_topk_weights_clone = recv_topk_weights.clone()\n                        if current_x is not x_pure_rand:\n                            recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(\n                                dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]\n                            check_data(recv_topk_weights, rank_prefix_matrix)\n\n                    # Test `num_worst_tokens != 0`\n                    if with_topk:\n                        num_worst_tokens = num_tokens * num_ranks\n                        dispatch_args.update({'num_worst_tokens': num_worst_tokens})\n                        recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)\n                        event.current_stream_wait() if async_mode else ()\n                        recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x\n                        
assert len(empty_list) == 0\n                        assert num_worst_tokens == recv_worst_x.size(0)\n                        assert num_worst_tokens == recv_worst_topk_idx.size(0)\n                        assert num_worst_tokens == recv_worst_topk_weights.size(0)\n                        assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])\n                        assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])\n                        assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])\n                        assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()\n\n                    # Test cached dispatch (must be run without top-k inputs)\n                    if not with_topk:\n                        dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}\n                        if previous_mode:\n                            dispatch_args.update({'previous_event': buffer.capture()})\n                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)\n                        event.current_stream_wait() if async_mode else ()\n                        recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x\n                        if current_x is not x_pure_rand:\n                            check_data(recv_x, rank_prefix_matrix)\n\n                    # Test combine\n                    combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}\n                    if with_topk:\n                        combine_args.update({'topk_weights': recv_topk_weights})\n                    if previous_mode:\n                        combine_args.update({'previous_event': buffer.capture()})\n                    combined_x, combined_topk_weights, event = buffer.combine(**combine_args)\n                    event.current_stream_wait() if async_mode else ()\n                    check_x = 
combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)\n                    ref_x = x_pure_rand if current_x is x_pure_rand else x\n                    assert calc_diff(check_x, ref_x) < 5e-6\n                    if with_topk:\n                        check_topk_weights = combined_topk_weights if (current_x\n                                                                       is x_pure_rand) else (combined_topk_weights /\n                                                                                             is_token_in_rank.sum(dim=1).unsqueeze(1))\n                        ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights\n                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9\n\n                    # For later tuning\n                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2\n                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes\n\n                    if local_rank == 0:\n                        print(' passed', flush=True)\n    if local_rank == 0:\n        print('', flush=True)\n\n    # Tune dispatch performance\n    best_dispatch_results = None\n    fp8_factor = (1 + 4 / 128) / 2\n    for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)):\n        best_time, best_results = 1e10, None\n        nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes\n        for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):\n            if nvl_chunk_size > 0:\n                config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)\n            else:\n                # Test default config as well\n                deep_ep.Buffer.set_num_sms(num_sms)\n                config = deep_ep.Buffer.get_dispatch_config(num_ranks)\n            tune_args = {'x': current_x, 'handle': handle, 'config': config}\n            t = bench(lambda: buffer.dispatch(**tune_args))[0]  # 
noqa: B023\n            if t < best_time and nvl_chunk_size > 0:\n                best_time, best_results = t, (num_sms, nvl_chunk_size)\n            if local_rank == 0:\n                print(\n                    f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else \"default\"}: '\n                    f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us',\n                    flush=True)\n        if local_rank == 0:\n            print(\n                f'[tuning] Best dispatch ({\"FP8\" if isinstance(current_x, tuple) else \"BF16\"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us',\n                flush=True)\n            print('', flush=True)\n\n        # Gather the best config from rank 0 and the first test setting\n        if best_dispatch_results is None:\n            best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')\n            all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]\n            dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)\n            best_dispatch_results = all_best_fp8_results_list[0].tolist()\n    dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)\n\n    dispatch_args = {\n        'x': x,\n        'num_tokens_per_rank': num_tokens_per_rank,\n        'is_token_in_rank': is_token_in_rank,\n        'num_tokens_per_expert': num_tokens_per_expert,\n        'config': dispatch_config if dispatch_config is not None else config\n    }\n    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)\n\n    # Tune combine performance\n    best_time, best_results = 1e10, None\n    for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):\n        if nvl_chunk_size > 0:\n            config = deep_ep.Config(num_sms, nvl_chunk_size, 
nvl_buffer_size)\n        else:\n            # Test default config as well\n            deep_ep.Buffer.set_num_sms(num_sms)\n            config = deep_ep.Buffer.get_combine_config(num_ranks)\n        tune_args = {'x': recv_x, 'handle': handle, 'config': config}\n        t = bench(lambda: buffer.combine(**tune_args))[0]  # noqa: B023\n        if local_rank == 0:\n            print(\n                f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else \"default\"}: '\n                f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us',\n                flush=True)\n            if t < best_time and nvl_chunk_size > 0:\n                best_time, best_results = t, (num_sms, nvl_chunk_size)\n\n    if local_rank == 0:\n        print(\n            f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us',\n            flush=True)\n        print('', flush=True)\n\n\n# noinspection PyUnboundLocalVariable,PyShadowingNames\ndef test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):\n    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)\n    test_ll_compatibility, num_rdma_bytes = False, 0\n    if test_ll_compatibility:\n        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9\n        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)\n\n    buffer = deep_ep.Buffer(group,\n                            int(2e9),\n                            num_rdma_bytes,\n                            low_latency_mode=test_ll_compatibility,\n                            num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),\n                            explicitly_destroy=True,\n                            allow_mnnvl=args.allow_mnnvl,\n                            use_fabric=args.use_fabric)\n   
 torch.manual_seed(rank)\n\n    for i in (24, ):\n        test_main(args, i, local_rank, num_ranks, rank, buffer, group)\n        if local_rank == 0:\n            print('', flush=True)\n\n    # Test compatibility with low latency functions\n    if test_ll_compatibility:\n        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)\n        test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)\n\n    # Destroy the buffer runtime and communication group\n    buffer.destroy()\n    dist.barrier()\n    dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Test intranode EP kernels')\n    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')\n    parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)')\n    parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')\n    parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')\n    parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)')\n    parser.add_argument('--allow-mnnvl', action=\"store_true\", help='Enable MNNVL support')\n    parser.add_argument('--use-fabric', action=\"store_true\", help='Enable fabric mode')\n    args = parser.parse_args()\n\n    num_processes = args.num_processes\n    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)\n"
  },
  {
    "path": "tests/test_low_latency.py",
    "content": "import argparse\nimport random\nimport torch\nimport torch.distributed as dist\nfrom functools import partial\nfrom typing import Literal, Set\n\nimport deep_ep\nfrom utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back\n\n\ndef simulate_failure_and_skip(rank: int, api: Literal[\"dispatch\", \"combine\", \"clean\"], expected_masked_ranks: Set[int]):\n    # Simulates rank failure when the rank first calls the corresponding communication API\n    failed_api_ranks = {\n        # API -> rank to fail (rank fails when it first calls the corresponding communication API)\n        'dispatch': 1,\n        'combine': 3,\n        'clean': 5\n    }\n    if rank in expected_masked_ranks:\n        # Rank already failed\n        return True\n    if api in failed_api_ranks.keys():\n        expected_masked_ranks.add(failed_api_ranks[api])\n        if failed_api_ranks[api] == rank:\n            print(f\"Rank {rank} failed when first calling {api} communication API, exit...\", flush=True)\n            return True\n    return False\n\n\ndef query_mask_buffer_and_check(api: Literal[\"dispatch\", \"combine\", \"clean\"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,\n                                expected_masked_ranks: Set[int]):\n    buffer.low_latency_query_mask_buffer(mask_status)\n    assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks\n\n\ndef test_main(num_tokens: int,\n              hidden: int,\n              num_experts: int,\n              num_topk: int,\n              rank: int,\n              num_ranks: int,\n              group: dist.ProcessGroup,\n              buffer: deep_ep.Buffer,\n              use_logfmt: bool = False,\n              shrink_test: bool = False,\n              seed: int = 0):\n    torch.manual_seed(seed + rank)\n    random.seed(seed + rank)\n\n    assert num_experts % num_ranks == 0\n    num_local_experts = num_experts // num_ranks\n\n    # NOTES: the integers greater 
than 256 exceed the BF16 precision limit\n    rank_offset = 128\n    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'\n\n    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)\n    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)\n    x_list = [x]\n    for _ in range(4 if use_logfmt else 0):\n        # NOTES: make more LogFMT casts and also with some BF16\n        x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())\n    # NOTES: the last one is for performance testing\n    # Most of the values in the perf case is lower than the threshold, casting most channels\n    x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)\n\n    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1\n    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]\n    topk_idx = topk_idx.to(deep_ep.topk_idx_t)\n    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()\n\n    # Randomly mask some positions\n    for _ in range(10):\n        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1\n\n    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')\n    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)\n\n    # For failure simulation and shrink testing\n    mask_status = torch.zeros((num_ranks, ), dtype=torch.int, device='cuda')\n    expected_masked_ranks = set()\n\n    # Check dispatch correctness\n    do_check = True\n    hash_value, num_times = 0, 0\n    for current_x in x_list:\n        for return_recv_hook in (False, True):\n            for dispatch_use_fp8 in (False, True):\n                for round_scale in (False, True) if dispatch_use_fp8 else (False, ):\n                    for 
use_ue8m0 in (False, True) if round_scale else (False, ):\n                        if shrink_test and simulate_failure_and_skip(rank, \"dispatch\", expected_masked_ranks):\n                            break\n                        num_times += 1\n                        for _ in range((num_times % 2) + 1):\n                            cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')\n                            packed_recv_x, packed_recv_count, handle, event, hook = \\\n                                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,\n                                                            use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,\n                                                            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,\n                                                            async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)\n                            hook() if return_recv_hook else event.current_stream_wait()\n                        if shrink_test:\n                            query_mask_buffer_and_check(\"dispatch\", buffer, mask_status, expected_masked_ranks)\n                        packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x\n                        simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \\\n                            if dispatch_use_fp8 else packed_recv_x.clone()\n                        for i in range(num_local_experts if do_check else 0):\n                            expert_id = rank * num_local_experts + i\n                            recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]\n                            recv_count, 
recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]\n\n                            # Check expert indices\n                            int_mask = (2**32) - 1\n                            num_valid_tokens = recv_count.item()\n                            assert cumulative_local_expert_recv_stats[i].item(\n                            ) == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'\n                            assert num_valid_tokens == (\n                                recv_layout_range\n                                & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'\n                            assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(\n                            ), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status==0].sum().item()}'\n\n                            if num_valid_tokens == 0:\n                                continue\n                            # Check received data\n                            if current_x is x:\n                                recv_x = recv_x[:num_valid_tokens]\n                                recv_x_amin = recv_x[:, :-128].amin(dim=-1)\n                                recv_src_info = recv_src_info[:num_valid_tokens]\n                                assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))\n                                if round_scale:\n                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007\n                                else:\n                                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0\n                                for j in range(num_ranks):\n                                    if shrink_test and mask_status[j]:\n                                        continue\n                          
          begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()\n                                    if not round_scale:\n                                        assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()\n                                        assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0\n                            if dispatch_use_fp8:\n                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])\n                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])\n                            else:\n                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])\n\n                        # Check combine correctness\n                        if shrink_test and simulate_failure_and_skip(rank, \"combine\", expected_masked_ranks):\n                            break\n                        for zero_copy in (False, ) if use_logfmt else (False, True):\n                            if zero_copy:\n                                buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x\n                            out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')\n                            combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,\n                                                                                 topk_idx,\n                                                                                 topk_weights,\n                                                                                 handle,\n                                                                                 use_logfmt=use_logfmt,\n                                                                                 async_finish=not return_recv_hook,\n                                    
                                             zero_copy=zero_copy,\n                                                                                 return_recv_hook=return_recv_hook,\n                                                                                 out=out)\n                            hook() if return_recv_hook else event.current_stream_wait()\n                            if shrink_test:\n                                query_mask_buffer_and_check(\"combine\", buffer, mask_status, expected_masked_ranks)\n                            if do_check:\n                                if shrink_test:\n                                    owner_by_expert = (torch.arange(num_experts, device='cuda') // num_local_experts)\n                                    fail_owner_mask = (mask_status == 1).index_select(0, owner_by_expert)\n                                    valid_topk_idx = topk_idx >= 0\n                                    failed_topk_idx = torch.zeros_like(topk_idx, device='cuda', dtype=torch.bool)\n                                    failed_topk_idx[valid_topk_idx] = fail_owner_mask.index_select(0, topk_idx[valid_topk_idx])\n                                    topk_idx[failed_topk_idx] = -1\n                                diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)\n                                assert torch.isnan(combined_x).sum().item() == 0\n                                if not round_scale:\n                                    assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'\n                                hash_value ^= hash_tensor(combined_x)\n\n                        # Clean buffer API\n                        if shrink_test:\n                            if simulate_failure_and_skip(rank, \"clean\", expected_masked_ranks):\n                                break\n\n                            
buffer.clean_low_latency_buffer(num_tokens, hidden, num_experts)\n                            query_mask_buffer_and_check(\"clean\", buffer, mask_status, expected_masked_ranks)\n\n    if shrink_test:\n        return\n\n    # noinspection PyShadowingNames\n    def large_gemm_with_hook(hook):\n        mat_0 = torch.randn((8192, 8192), dtype=torch.float)\n        mat_1 = torch.randn((8192, 8192), dtype=torch.float)\n        mat_0 @ mat_1\n        hook()\n\n    # noinspection PyShadowingNames\n    def test_func(return_recv_hook: bool):\n        recv_x, recv_count, handle, event, hook = \\\n            buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,\n                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,\n                                        use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)\n        large_gemm_with_hook(hook) if return_recv_hook else None\n        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,\n                                                             topk_idx,\n                                                             topk_weights,\n                                                             handle,\n                                                             use_logfmt=use_logfmt,\n                                                             return_recv_hook=return_recv_hook)\n        large_gemm_with_hook(hook) if return_recv_hook else None\n\n    # Calculate bandwidth\n    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2\n    num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4\n    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0\n    for i in range(num_tokens):\n        num_selections = (topk_idx[i] != -1).sum().item()\n        num_dispatch_comm_bytes += num_fp8_bytes * num_selections\n        num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * 
num_selections\n\n    # Dispatch + combine testing\n    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))\n    print(\n        f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '\n        f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us',\n        flush=True)\n\n    # Separate profiling\n    for return_recv_hook in (False, True):\n        group.barrier()\n        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),\n                                             kernel_names=('dispatch', 'combine'),\n                                             barrier_comm_profiling=True,\n                                             suppress_kineto_output=True,\n                                             num_kernels_per_period=2 if return_recv_hook else 1)\n        if not return_recv_hook:\n            print(\n                f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '\n                f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',\n                flush=True)\n        else:\n            print(\n                f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '\n                f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',\n                flush=True)\n    return hash_value\n\n\n# noinspection PyUnboundLocalVariable,PyShadowingNames\ndef test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):\n    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)\n    num_tokens, hidden = args.num_tokens, args.hidden\n    num_topk, num_experts = args.num_topk, args.num_experts\n\n    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, 
num_ranks, num_experts)\n    if local_rank == 0:\n        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)\n    buffer = deep_ep.Buffer(group,\n                            num_rdma_bytes=num_rdma_bytes,\n                            low_latency_mode=True,\n                            num_qps_per_rank=num_experts // num_ranks,\n                            allow_nvlink_for_low_latency_mode=not args.disable_nvlink,\n                            explicitly_destroy=True,\n                            allow_mnnvl=args.allow_mnnvl,\n                            enable_shrink=args.shrink_test)\n    test_main(num_tokens,\n              hidden,\n              num_experts,\n              num_topk,\n              rank,\n              num_ranks,\n              group,\n              buffer,\n              use_logfmt=args.use_logfmt,\n              shrink_test=args.shrink_test,\n              seed=1)\n\n    do_pressure_test = args.pressure_test\n    for seed in range(int(1e9) if do_pressure_test else 0):\n        if local_rank == 0:\n            print(f'Testing with seed {seed} ...', flush=True)\n        ref_hash = test_main(num_tokens,\n                             hidden,\n                             num_experts,\n                             num_topk,\n                             rank,\n                             num_ranks,\n                             group,\n                             buffer,\n                             use_logfmt=args.use_logfmt,\n                             seed=seed)\n        for _ in range(20):\n            assert test_main(num_tokens,\n                             hidden,\n                             num_experts,\n                             num_topk,\n                             rank,\n                             num_ranks,\n                             group,\n                             buffer,\n                             use_logfmt=args.use_logfmt,\n                             seed=seed) == ref_hash, f'Error: 
seed={seed}'\n\n    # Destroy the buffer runtime and communication group\n    buffer.destroy()\n    dist.barrier()\n    dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    # TODO: you may modify NUMA binding for less CPU overhead\n    # TODO: buggy with `num_tokens=512`\n    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')\n    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')\n    parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')\n    parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')\n    parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')\n    parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)')\n    parser.add_argument('--allow-mnnvl', action=\"store_true\", help='Allow MNNVL for communication')\n    parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')\n    parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')\n    parser.add_argument(\"--pressure-test\", action='store_true', help='Whether to do pressure test')\n    parser.add_argument(\"--shrink-test\", action='store_true', help='Whether to simulate failure and test shrink mode')\n    args = parser.parse_args()\n\n    num_processes = args.num_processes\n    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)\n"
  },
  {
    "path": "tests/utils.py",
    "content": "import inspect\nimport json\nimport tempfile\nfrom pathlib import Path\n\nimport numpy as np\nimport os\nimport sys\nimport torch\nimport torch.distributed as dist\nfrom typing import Optional, Union\n\n\ndef init_dist(local_rank: int, num_local_ranks: int):\n    # NOTES: you may rewrite this function with your own cluster settings\n    ip = os.getenv('MASTER_ADDR', '127.0.0.1')\n    port = int(os.getenv('MASTER_PORT', '8361'))\n    num_nodes = int(os.getenv('WORLD_SIZE', 1))\n    node_rank = int(os.getenv('RANK', 0))\n\n    sig = inspect.signature(dist.init_process_group)\n    params = {\n        'backend': 'nccl',\n        'init_method': f'tcp://{ip}:{port}',\n        'world_size': num_nodes * num_local_ranks,\n        'rank': node_rank * num_local_ranks + local_rank,\n    }\n    if 'device_id' in sig.parameters:\n        # noinspection PyTypeChecker\n        params['device_id'] = torch.device(f'cuda:{local_rank}')\n    dist.init_process_group(**params)\n    torch.set_default_dtype(torch.bfloat16)\n    torch.set_default_device('cuda')\n    torch.cuda.set_device(local_rank)\n\n    return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))\n\n\ndef calc_diff(x: torch.Tensor, y: torch.Tensor):\n    x, y = x.double() + 1, y.double() + 1\n    denominator = (x * x + y * y).sum()\n    sim = 2 * (x * y).sum() / denominator\n    return (1 - sim).item()\n\n\ndef align_up(x, y):\n    return (x + y - 1) // y * y\n\n\ndef per_token_cast_to_fp8(x: torch.Tensor):\n    assert x.dim() == 2\n    m, n = x.shape\n    aligned_n = align_up(n, 128)\n    x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)\n    x_padded_view = x_padded.view(m, -1, 128)\n    x_amax = x_padded_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)\n    return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(\n        m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, 
-1)\n\n\ndef per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):\n    if x_fp8.numel() == 0:\n        return x_fp8.to(torch.bfloat16)\n\n    assert x_fp8.dim() == 2\n    m, n = x_fp8.shape\n    aligned_n = align_up(n, 128)\n    x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0)\n    if x_scales.dtype == torch.int:\n        x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23\n        x_scales = x_scales.view(dtype=torch.float)\n    x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128)\n    x_scales = x_scales.view(x_fp8.size(0), -1, 1)\n    return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:, :n].contiguous()\n\n\ndef inplace_unique(x: torch.Tensor, num_slots: int):\n    assert x.dim() == 2\n    mask = x < 0\n    x_padded = x.masked_fill(mask, num_slots)\n    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)\n    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))\n    bin_count = bin_count[:, :num_slots]\n    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)\n    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)\n    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values\n    x[:, :].fill_(-1)\n    valid_len = min(num_slots, x.size(1))\n    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]\n\n\ndef create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):\n    num_tokens, num_experts = scores.shape\n    scores = scores.view(num_tokens, num_groups, -1)\n    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)\n    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)\n    return (scores * mask).view(num_tokens, num_experts)\n\n\ndef bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None):\n    # Flush L2 cache with 256 MB data\n    torch.cuda.synchronize()\n    
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')\n\n    # Warmup\n    for _ in range(num_warmups):\n        fn()\n\n    # Flush L2\n    cache.zero_()\n\n    # Testing\n    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]\n    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]\n    for i in range(num_tests):\n        # Record\n        start_events[i].record()\n        fn()\n        end_events[i].record()\n        if post_fn is not None:\n            post_fn()\n    torch.cuda.synchronize()\n\n    times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]\n    return np.average(times), np.min(times), np.max(times)\n\n\nclass empty_suppress:\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, *_):\n        pass\n\n\nclass suppress_stdout_stderr:\n\n    def __enter__(self):\n        self.outnull_file = open(os.devnull, 'w')\n        self.errnull_file = open(os.devnull, 'w')\n\n        self.old_stdout_fileno_undup = sys.stdout.fileno()\n        self.old_stderr_fileno_undup = sys.stderr.fileno()\n\n        self.old_stdout_fileno = os.dup(sys.stdout.fileno())\n        self.old_stderr_fileno = os.dup(sys.stderr.fileno())\n\n        self.old_stdout = sys.stdout\n        self.old_stderr = sys.stderr\n\n        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)\n        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)\n\n        sys.stdout = self.outnull_file\n        sys.stderr = self.errnull_file\n        return self\n\n    def __exit__(self, *_):\n        sys.stdout = self.old_stdout\n        sys.stderr = self.old_stderr\n\n        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)\n        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)\n\n        os.close(self.old_stdout_fileno)\n        os.close(self.old_stderr_fileno)\n\n        self.outnull_file.close()\n        
self.errnull_file.close()\n\n\ndef bench_kineto(fn,\n                 kernel_names: Union[str, tuple],\n                 num_tests: int = 30,\n                 suppress_kineto_output: bool = False,\n                 trace_path: Optional[str] = None,\n                 barrier_comm_profiling: bool = False,\n                 num_kernels_per_period: int = 1):\n    # Profile\n    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress\n    with suppress():\n        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)\n        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:\n            for _ in range(2):\n                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead\n                if barrier_comm_profiling:\n                    lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')\n                    rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')\n                    lhs @ rhs\n                    dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))\n                for _ in range(num_tests):\n                    fn()\n                torch.cuda.synchronize()\n                prof.step()\n\n    # Parse the profiling table\n    assert isinstance(kernel_names, (str, tuple))\n    is_tuple = isinstance(kernel_names, tuple)\n    prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\\n')\n    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names\n    assert all([isinstance(name, str) for name in kernel_names])\n    for name in kernel_names:\n        assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'\n\n    # Save chrome traces\n    if trace_path is not None:\n        prof.export_chrome_trace(trace_path)\n\n    # Return average kernel durations\n    units 
= {'ms': 1e3, 'us': 1e6}\n    kernel_durations = []\n    for name in kernel_names:\n        for line in prof_lines:\n            if name in line:\n                time_str = line.split()[-2]\n                for unit, scale in units.items():\n                    if unit in time_str:\n                        kernel_durations.append(float(time_str.replace(unit, '')) / scale)\n                        break\n                break\n\n    # Expand the kernels by periods\n    if num_kernels_per_period > 1:\n        with tempfile.NamedTemporaryFile(suffix='.json') as tmp:\n            prof.export_chrome_trace(tmp.name)\n            profile_data = json.loads(Path(tmp.name).read_text())\n\n        for i, kernel_name in enumerate(kernel_names):\n            events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']]\n            events = sorted(events, key=lambda event: event['ts'])\n            durations = [event['dur'] / 1e6 for event in events]\n            assert len(durations) % num_kernels_per_period == 0\n            num_kernel_patterns = len(durations) // num_kernels_per_period\n            kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns for j in range(num_kernels_per_period)]\n\n    # Return execution durations\n    return kernel_durations if is_tuple else kernel_durations[0]\n\n\ndef hash_tensor(t: torch.Tensor):\n    return t.view(torch.int).sum().item()\n"
  },
  {
    "path": "third-party/README.md",
    "content": "# Install NVSHMEM\n\n## Important notices\n\n**This project is neither sponsored nor supported by NVIDIA.**\n\n**Use of NVIDIA NVSHMEM is governed by the terms at [NVSHMEM Software License Agreement](https://docs.nvidia.com/nvshmem/api/sla.html).**\n\n## Prerequisites\n\nHardware requirements:\n   - GPUs within a node need to be connected by NVLink\n   - GPUs across nodes need to be connected by RDMA devices, see [GPUDirect RDMA Documentation](https://docs.nvidia.com/cuda/gpudirect-rdma/)\n   - InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/)\n   - For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements)\n\nSoftware requirements:\n   - NVSHMEM v3.3.9 or later\n\n## Installation procedure\n\n### 1. Install NVSHMEM binaries\n\nNVSHMEM 3.3.9 binaries are available in several formats:\n   - Tarballs for [x86_64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-x86_64/libnvshmem-linux-x86_64-3.3.9_cuda12-archive.tar.xz) and [aarch64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-sbsa/libnvshmem-linux-sbsa-3.3.9_cuda12-archive.tar.xz)\n   - RPM and deb packages: instructions can be found on the [NVSHMEM installer page](https://developer.nvidia.com/nvshmem-downloads?target_os=Linux)\n   - Conda packages through conda-forge\n   - pip wheels through PyPI: `pip install nvidia-nvshmem-cu12`\n\nDeepEP is compatible with upstream NVSHMEM 3.3.9 and later.\n\n### 2. Enable NVSHMEM IBGDA support\n\nNVSHMEM supports two IBGDA modes with different requirements. Either of the following methods can be used to enable IBGDA support.\n\n#### 2.1 Configure NVIDIA driver\n\nThis configuration enables traditional IBGDA support.\n\nModify `/etc/modprobe.d/nvidia.conf`:\n\n```bash\noptions nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords=\"PeerMappingOverride=1;\"\n```\n\nUpdate the initramfs and reboot so the new driver options take effect:\n\n```bash\nsudo update-initramfs -u\nsudo reboot\n```\n\n#### 2.2 Install GDRCopy and load the gdrdrv kernel module\n\nThis configuration enables IBGDA through asynchronous post-send operations assisted by the CPU. More information about CPU-assisted IBGDA can be found in [this blog](https://developer.nvidia.com/blog/enhancing-application-portability-and-compatibility-across-new-platforms-using-nvidia-magnum-io-nvshmem-3-0/#cpu-assisted_infiniband_gpu_direct_async%C2%A0).\nThis mode comes with a small performance penalty, but it can be used when modifying the driver registry keys is not an option.\n\nDownload GDRCopy, which is available as prebuilt deb and RPM packages [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/) or as source code in the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy).\n\nInstall GDRCopy by following the instructions in the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy?tab=readme-ov-file#build-and-installation).\n\n## Post-installation configuration\n\nWhen not installing NVSHMEM from RPM or deb packages, set the following environment variables in your shell configuration:\n\n```bash\nexport NVSHMEM_DIR=/path/to/your/dir/to/install  # Use for DeepEP installation\nexport LD_LIBRARY_PATH=\"${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH\"\nexport PATH=\"${NVSHMEM_DIR}/bin:$PATH\"\n```\n\n## Verification\n\n```bash\nnvshmem-info -a  # Should display details of the NVSHMEM installation\n```\n"
  },
  {
    "path": "third-party/nvshmem.patch",
    "content": "From 9e6cc27cceb3130784e4ea7b61ea3171156365fd Mon Sep 17 00:00:00 2001\nFrom: Shangyan Zhou <sy.zhou@deepseek.com>\nDate: Fri, 20 Dec 2024 10:57:12 +0800\nSubject: [PATCH 1/4] Change QP creating order.\n\n---\n src/modules/transport/ibgda/ibgda.cpp | 13 ++++++++-----\n 1 file changed, 8 insertions(+), 5 deletions(-)\n\ndiff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp\nindex ef325cd..286132e 100644\n--- a/src/modules/transport/ibgda/ibgda.cpp\n+++ b/src/modules/transport/ibgda/ibgda.cpp\n@@ -2936,17 +2936,20 @@ int nvshmemt_ibgda_connect_endpoints(nvshmem_transport_t t, int *selected_dev_id\n         INFO(ibgda_state->log_level, \"Creating %d RC QPs\", device->rc.num_eps_per_pe);\n         for (int i = 0; i < num_rc_eps; ++i) {\n             // Do not create loopback to self\n-            if (i / device->rc.num_eps_per_pe == mype) {\n+            int dst_pe = (i + 1 + mype) % n_pes;\n+            int offset = i / n_pes;\n+            int mapped_i = dst_pe * device->rc.num_eps_per_pe + offset;\n+            if (dst_pe == mype) {\n                 continue;\n             }\n-            status = ibgda_create_qp(&device->rc.eps[i], device, portid, i,\n+            status = ibgda_create_qp(&device->rc.eps[mapped_i], device, portid, mapped_i,\n                                      NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC);\n             NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,\n-                                  \"ibgda_create_dci failed on RC #%d.\", i);\n+                                  \"ibgda_create_dci failed on RC #%d.\", mapped_i);\n\n-            status = ibgda_get_rc_handle(&local_rc_handles[i], device->rc.eps[i], device);\n+            status = ibgda_get_rc_handle(&local_rc_handles[mapped_i], device->rc.eps[mapped_i], device);\n             NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,\n-                                  \"ibgda_get_rc_handle failed on RC #%d.\", i);\n+         
                         \"ibgda_get_rc_handle failed on RC #%d.\", mapped_i);\n         }\n\n         if (num_rc_eps) {\n--\n2.25.1\n\n\nFrom b11d41e4f3727f2f6ccc00a8c852e59e2ee33c8a Mon Sep 17 00:00:00 2001\nFrom: Shangyan Zhou <sy.zhou@deepseek.com>\nDate: Fri, 10 Jan 2025 11:53:38 +0800\nSubject: [PATCH 2/4] Add recv queue and recv cq for rc qps.\n\nLet the ibgda rc qps use regular recv queue.\n\nAdd recv queue to ibgda dev qp.\n\nIBGDA create recv cq\n\nSetup recv cq.\n\nfix recv queue.\n\nRemove some useless idx.\n\nLonger recv queue.\n---\n .../nvshmem_common_ibgda.h                    | 19 +++++-\n src/modules/transport/ibgda/ibgda.cpp         | 65 ++++++++++++++++---\n 2 files changed, 71 insertions(+), 13 deletions(-)\n\ndiff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h\nindex 8b8a263..1be3dec 100644\n--- a/src/include/device_host_transport/nvshmem_common_ibgda.h\n+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h\n@@ -168,14 +168,17 @@ typedef struct {\n         uint64_t get_head;    // last wqe idx + 1 with a \"fetch\" operation (g, get, amo_fetch)\n         uint64_t get_tail;    // last wqe idx + 1 polled with cst; get_tail > get_head is possible\n     } tx_wq;\n+    struct {\n+        uint64_t resv_head;   // last reserved wqe idx + 1\n+    } rx_wq;\n     struct {\n         uint64_t head;\n         uint64_t tail;\n     } ibuf;\n     char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];\n } __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;\n-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 96,\n-              \"ibgda_device_qp_management_v1 must be 96 bytes.\");\n+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,\n+              \"ibgda_device_qp_management_v1 must be 104 bytes.\");\n\n typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;\n\n@@ -199,9 +202,19 @@ typedef struct 
nvshmemi_ibgda_device_qp {\n         // May point to mvars.prod_idx or internal prod_idx\n         uint64_t *prod_idx;\n     } tx_wq;\n+    struct {\n+        uint16_t nwqes;\n+        uint64_t tail;\n+        void *wqe;\n+        __be32 *dbrec;\n+        void *bf;\n+        nvshmemi_ibgda_device_cq_t *cq;\n+        // May point to mvars.prod_idx or internal prod_idx\n+        uint64_t *prod_idx;\n+    } rx_wq;\n     nvshmemi_ibgda_device_qp_management_v1 mvars;  // management variables\n } nvshmemi_ibgda_device_qp_v1;\n-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 184, \"ibgda_device_qp_v1 must be 184 bytes.\");\n+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, \"ibgda_device_qp_v1 must be 248 bytes.\");\n\n typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;\n\ndiff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp\nindex 286132e..e0b2d5c 100644\n--- a/src/modules/transport/ibgda/ibgda.cpp\n+++ b/src/modules/transport/ibgda/ibgda.cpp\n@@ -198,6 +198,7 @@ struct ibgda_ep {\n     off_t dbr_offset;\n\n     struct ibgda_cq *send_cq;\n+    struct ibgda_cq *recv_cq;\n     struct ibv_ah *ah;\n\n     uint32_t user_index;\n@@ -1538,7 +1539,8 @@ static int ibgda_create_cq_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,\n\n     struct ibv_context *context = device->context;\n\n-    unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes;\n+    // Each RC qp has one send CQ and one recv CQ.\n+    unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes * 2;\n\n     assert(ibgda_qp_depth > 0);\n     size_t num_cqe = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);\n@@ -1701,7 +1703,8 @@ static int ibgda_create_qp_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,\n     }\n\n     // Allocate and map WQ buffer for all QPs.\n-    wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB;  // num_wqebb is always a power of 2\n+    // Todo: reduce the size of wq buffer.\n+    
wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB * 2;  // num_wqebb is always a power of 2\n     wq_buf_size = wq_buf_size_per_qp * num_eps;\n     status = ibgda_nic_control_alloc(&wq_mobject, wq_buf_size, IBGDA_GPAGE_SIZE);\n     NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, \"cannot allocate wq buf.\\n\");\n@@ -1882,8 +1885,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device\n     int cqe_version = 0;\n\n     struct ibgda_cq *send_cq = NULL;\n+    struct ibgda_cq *recv_cq = NULL;\n\n     size_t num_wqebb = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);\n+    size_t num_recv_wqe = ibgda_qp_depth;\n+    size_t recv_wqe_size = 16;\n\n     int status = 0;\n\n@@ -1911,6 +1917,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device\n     status = ibgda_create_cq(&send_cq, device);\n     NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, \"ibgda_create_cq failed.\\n\");\n\n+    if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {\n+        status = ibgda_create_cq(&recv_cq, device);\n+        NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, \"ibgda_create_cq failed.\\n\");\n+    }\n+\n     ep = (struct ibgda_ep *)calloc(1, sizeof(struct ibgda_ep));\n     NVSHMEMI_NULL_ERROR_JMP(ep, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,\n                             \"Unable to allocate mem for ep.\\n\");\n@@ -1939,12 +1950,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device\n     DEVX_SET(qpc, qp_context, pm_state, MLX5_QPC_PM_STATE_MIGRATED);\n     DEVX_SET(qpc, qp_context, pd, device->qp_shared_object.pdn);\n     DEVX_SET(qpc, qp_context, uar_page, uar_mobject->uar->page_id);  // BF register\n-    DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE);        // Shared Receive Queue\n-    DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);\n     DEVX_SET(qpc, qp_context, cqn_snd, send_cq->cqn);\n-    DEVX_SET(qpc, qp_context, 
cqn_rcv, device->qp_shared_object.rcqn);\n+    DEVX_SET(qpc, qp_context, cqn_rcv, qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC ? recv_cq->cqn : device->qp_shared_object.rcqn);\n     DEVX_SET(qpc, qp_context, log_sq_size, IBGDA_ILOG2_OR0(num_wqebb));\n-    DEVX_SET(qpc, qp_context, log_rq_size, 0);\n     DEVX_SET(qpc, qp_context, cs_req, 0);                                     // Disable CS Request\n     DEVX_SET(qpc, qp_context, cs_res, 0);                                     // Disable CS Response\n     DEVX_SET(qpc, qp_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);  // Enable dbr_umem_id\n@@ -1953,6 +1961,15 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device\n     DEVX_SET(qpc, qp_context, dbr_umem_id, dbr_umem->umem_id);  // DBR buffer\n     DEVX_SET(qpc, qp_context, user_index, qp_idx);\n     DEVX_SET(qpc, qp_context, page_offset, 0);\n+    if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC){\n+        DEVX_SET(qpc, qp_context, rq_type, 0);        // Regular recv queue\n+        DEVX_SET(qpc, qp_context, log_rq_size, IBGDA_ILOG2(num_recv_wqe)); // 4 wqe\n+        DEVX_SET(qpc, qp_context, log_rq_stride, IBGDA_ILOG2(recv_wqe_size) - 4); // max recv wqe size = 16B\n+    } else {\n+        DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE);        // Shared Receive Queue, DC must use this.\n+        DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);\n+        DEVX_SET(qpc, qp_context, log_rq_size, 0);\n+    }\n\n     ep->devx_qp = mlx5dv_devx_obj_create(context, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out));\n     NVSHMEMI_NULL_ERROR_JMP(ep->devx_qp, status, NVSHMEMX_ERROR_INTERNAL, out,\n@@ -1962,9 +1979,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device\n     ep->portid = portid;\n\n     ep->sq_cnt = num_wqebb;\n-    ep->sq_buf_offset = 0;\n+    ep->sq_buf_offset = num_recv_wqe * recv_wqe_size;\n\n-    ep->rq_cnt = 0;\n+    ep->rq_cnt = num_recv_wqe;\n     
ep->rq_buf_offset = 0;\n\n     ep->wq_mobject = device->qp_shared_object.wq_mobject;\n@@ -1978,6 +1995,7 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device\n     ep->uar_mobject = uar_mobject;\n\n     ep->send_cq = send_cq;\n+    ep->recv_cq = recv_cq;\n\n     ep->qp_type = qp_type;\n\n@@ -1989,6 +2007,7 @@ out:\n     if (status) {\n         if (uar_mobject) ibgda_unmap_and_free_qp_uar(uar_mobject);\n         if (send_cq) ibgda_destroy_cq(send_cq);\n+        if (recv_cq) ibgda_destroy_cq(recv_cq);\n         if (ep) free(ep);\n     }\n\n@@ -2287,6 +2306,10 @@ static int ibgda_destroy_ep(struct ibgda_ep *ep) {\n         ibgda_destroy_cq(ep->send_cq);\n     }\n\n+    if (ep->recv_cq) {\n+        ibgda_destroy_cq(ep->recv_cq);\n+    }\n+\n     if (ep->ah) {\n         ftable.destroy_ah(ep->ah);\n     }\n@@ -2318,7 +2341,7 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda\n     dev_qp->qpn = ep->qpn;\n\n     assert(ep->wq_mobject->has_gpu_mapping);\n-    dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset);\n+    dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->sq_buf_offset);\n\n     if (ibgda_nic_handler == IBGDA_NIC_HANDLER_GPU) {\n         assert(ep->dbr_mobject->has_gpu_mapping);\n@@ -2330,6 +2353,12 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda\n     }\n\n     dev_qp->tx_wq.nwqes = ep->sq_cnt;\n+    if (ep->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {\n+        dev_qp->rx_wq.nwqes = ep->rq_cnt;\n+        dev_qp->rx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->rq_buf_offset);\n+        dev_qp->rx_wq.dbrec = (__be32 *)((uintptr_t)ep->dbr_mobject->aligned.gpu_ptr + ep->dbr_offset);\n+        dev_qp->rx_wq.bf = (void *)ep->uar_mobject->aligned.gpu_ptr;\n+    }\n\n     ibuf_dci_start = 
(uintptr_t)device->qp_shared_object.internal_buf.mem_object->aligned.gpu_ptr;\n     ibuf_rc_start = ibuf_dci_start + (size_per_dci * device->dci.num_eps);\n@@ -2379,6 +2408,9 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {\n     nvshmemi_ibgda_device_cq_t *cq_d = NULL;\n     nvshmemi_ibgda_device_cq_t *cq_h = NULL;\n\n+    nvshmemi_ibgda_device_cq_t *recv_cq_d = NULL;\n+    nvshmemi_ibgda_device_cq_t *recv_cq_h = NULL;\n+\n     uint8_t *qp_group_switches_d = NULL;\n\n     const size_t mvars_offset = offsetof(nvshmemi_ibgda_device_qp_t, mvars);\n@@ -2386,6 +2418,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {\n     const size_t cons_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.cons_idx);\n     const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);\n     const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);\n+    const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);\n\n     nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;\n     nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;\n@@ -2421,7 +2454,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {\n         num_dct_handles += device->dct.num_eps * n_pes;\n         num_dci_handles += device->dci.num_eps;\n         num_rc_handles += device->rc.num_eps_per_pe * n_pes;\n-        num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1));\n+        num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1) * 2);\n         num_shared_dci_handles += device->dci.num_shared_eps;\n     }\n     assert(num_dci_handles - num_shared_dci_handles >= 0);\n@@ -2456,6 +2489,10 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {\n     for (int i = 0; i < num_cq_handles; i++) {\n         
nvshmemi_init_ibgda_device_cq(cq_h[i]);\n     }\n+\n+    recv_cq_h = (nvshmemi_ibgda_device_cq_t *)calloc(1, sizeof(*recv_cq_h));\n+    NVSHMEMI_NULL_ERROR_JMP(recv_cq_h, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, \"recv_cq calloc err.\");\n+    nvshmemi_init_ibgda_device_cq(recv_cq_h[0]);\n     /* allocate host memory for dct, rc, cq, dci end */\n\n     /* allocate device memory for dct, rc, cq, dci start */\n@@ -2559,6 +2596,14 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {\n                 }\n\n                 ++cq_idx;\n+\n+                rc_h[arr_idx].rx_wq.cq = &cq_d[cq_idx];\n+\n+                ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);\n+                cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);\n+                cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;\n+                cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;\n+                ++cq_idx;\n             }\n         }\n     }\n--\n2.25.1\n\n\nFrom af479f9f23103d4a1579fae38676d6b3022df887 Mon Sep 17 00:00:00 2001\nFrom: Shangyan Zhou <sy.zhou@deepseek.com>\nDate: Sat, 8 Feb 2025 18:02:39 +0800\nSubject: [PATCH 3/4] Maintain recv queue's cons_idx.\n\n---\n src/include/device_host_transport/nvshmem_common_ibgda.h | 5 +++--\n src/modules/transport/ibgda/ibgda.cpp                    | 6 ++++--\n 2 files changed, 7 insertions(+), 4 deletions(-)\n\ndiff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h\nindex 1be3dec..ea1e284 100644\n--- a/src/include/device_host_transport/nvshmem_common_ibgda.h\n+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h\n@@ -170,6 +170,7 @@ typedef struct {\n     } tx_wq;\n     struct {\n         uint64_t resv_head;   // last reserved wqe idx + 1\n+        uint64_t cons_idx;    // polled wqe idx + 1 (consumer index + 1)\n     } rx_wq;\n     struct {\n         uint64_t head;\n@@ -177,7 +178,7 @@ typedef struct {\n     } ibuf;\n   
  char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];\n } __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;\n-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,\n-              \"ibgda_device_qp_management_v1 must be 104 bytes.\");\n+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 112,\n+              \"ibgda_device_qp_management_v1 must be 112 bytes.\");\n\n typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;\n@@ -214,7 +215,7 @@ typedef struct nvshmemi_ibgda_device_qp {\n     } rx_wq;\n     nvshmemi_ibgda_device_qp_management_v1 mvars;  // management variables\n } nvshmemi_ibgda_device_qp_v1;\n-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, \"ibgda_device_qp_v1 must be 248 bytes.\");\n+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 256, \"ibgda_device_qp_v1 must be 256 bytes.\");\n\n typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;\n\ndiff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp\nindex e0b2d5c..bc339c5 100644\n--- a/src/modules/transport/ibgda/ibgda.cpp\n+++ b/src/modules/transport/ibgda/ibgda.cpp\n@@ -1067,7 +1067,7 @@ static inline void ibgda_nic_control_free(struct ibgda_mem_object *mobject) {\n         ibgda_host_mem_free(mobject);\n }\n\n-static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) {\n+static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device, int cc = 1) {\n     int status = 0;\n\n     struct ibgda_cq *gcq = NULL;\n@@ -1118,7 +1118,7 @@ static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device)\n     cq_context = DEVX_ADDR_OF(create_cq_in, cmd_in, cq_context);\n     DEVX_SET(cqc, cq_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);\n     DEVX_SET(cqc, cq_context, cqe_sz, MLX5_CQE_SIZE_64B);\n-    DEVX_SET(cqc, cq_context, cc, 0x1);  // Use collapsed CQ\n+    DEVX_SET(cqc, cq_context, cc, cc);  // Use collapsed 
CQ\n     DEVX_SET(cqc, cq_context, oi, 0x1);  // Allow overrun\n     DEVX_SET(cqc, cq_context, dbr_umem_id, dbr_umem->umem_id);\n     DEVX_SET(cqc, cq_context, log_cq_size, IBGDA_ILOG2_OR0(num_cqe));\n@@ -2419,6 +2419,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {\n     const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);\n     const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);\n     const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);\n+    const size_t rx_cons_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.cons_idx);\n\n     nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;\n     nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;\n@@ -2601,6 +2602,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {\n\n                 ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);\n                 cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);\n+                cq_h[cq_idx].cons_idx = (uint64_t *)(base_mvars_d_addr + rx_cons_offset);\n                 cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;\n                 cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;\n                 ++cq_idx;\n--\n2.25.1\n\n\nFrom e0ba3fa21b4b633b481c6684c3ad04f2670c8df4 Mon Sep 17 00:00:00 2001\nFrom: Shangyan Zhou <sy.zhou@deepseek.com>\nDate: Tue, 11 Feb 2025 11:00:57 +0800\nSubject: [PATCH 4/4] Init rx_wq counters.\n\n---\n src/include/device_host_transport/nvshmem_common_ibgda.h | 2 ++\n 1 file changed, 2 insertions(+)\n\ndiff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h\nindex ea1e284..e6640d6 100644\n--- a/src/include/device_host_transport/nvshmem_common_ibgda.h\n+++ 
b/src/include/device_host_transport/nvshmem_common_ibgda.h\n@@ -46,6 +46,8 @@\n         qp_man.tx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \\\n         qp_man.tx_wq.get_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \\\n         qp_man.tx_wq.get_tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \\\n+        qp_man.rx_wq.resv_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \\\n+        qp_man.rx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                    \\\n         qp_man.ibuf.head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                         \\\n         qp_man.ibuf.tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID;                         \\\n     } while (0);\n--\n2.25.1\n\ndiff --git a/src/modules/transport/common/transport_ib_common.cpp b/src/modules/transport/common/transport_ib_common.cpp\nindex c89f408..f99018a 100644\n--- a/src/modules/transport/common/transport_ib_common.cpp\n+++ b/src/modules/transport/common/transport_ib_common.cpp\n@@ -26,6 +26,9 @@ int nvshmemt_ib_common_nv_peer_mem_available() {\n     if (access(\"/sys/kernel/mm/memory_peers/nvidia-peermem/version\", F_OK) == 0) {\n         return NVSHMEMX_SUCCESS;\n     }\n+    if (access(\"/sys/module/nvidia_peermem/version\", F_OK) == 0) {\n+        return NVSHMEMX_SUCCESS;\n+    }\n \n     return NVSHMEMX_ERROR_INTERNAL;\n }\n\n\nFrom 099f608fcd9a1d34c866ad75d0af5d02d2020374 Mon Sep 17 00:00:00 2001\nFrom: Kaichao You <youkaichao@gmail.com>\nDate: Tue, 10 Jun 2025 00:35:03 -0700\nSubject: [PATCH] remove gdrcopy dependency\n\n---\n src/modules/transport/ibgda/ibgda.cpp | 6 ++++++\n 1 file changed, 6 insertions(+)\n\ndiff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp\nindex ef325cd..16ee09c 100644\n--- a/src/modules/transport/ibgda/ibgda.cpp\n+++ b/src/modules/transport/ibgda/ibgda.cpp\n@@ -406,6 +406,7 @@ static size_t ibgda_get_host_page_size() {\n     return host_page_size;\n }\n\n+#ifdef 
NVSHMEM_USE_GDRCOPY\n int nvshmemt_ibgda_progress(nvshmem_transport_t t) {\n     nvshmemt_ibgda_state_t *ibgda_state = (nvshmemt_ibgda_state_t *)t->state;\n     int n_devs_selected = ibgda_state->n_devs_selected;\n@@ -459,6 +460,11 @@ int nvshmemt_ibgda_progress(nvshmem_transport_t t) {\n     }\n     return 0;\n }\n+#else\n+int nvshmemt_ibgda_progress(nvshmem_transport_t t) {\n+    return NVSHMEMX_ERROR_NOT_SUPPORTED;\n+}\n+#endif\n\n int nvshmemt_ibgda_show_info(struct nvshmem_transport *transport, int style) {\n     NVSHMEMI_ERROR_PRINT(\"ibgda show info not implemented\");\n--\n2.34.1\n"
  }
]