Repository: deepseek-ai/FlashMLA Branch: main Commit: 47c35a712362 Files: 129 Total size: 1.1 MB Directory structure: gitextract_clsc5nbn/ ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── benchmark/ │ ├── bench_flash_mla.py │ └── visualize.py ├── csrc/ │ ├── api/ │ │ ├── api.cpp │ │ ├── common.h │ │ ├── dense_decode.h │ │ ├── dense_fwd.h │ │ ├── sparse_decode.h │ │ └── sparse_fwd.h │ ├── defines.h │ ├── kerutils/ │ │ └── include/ │ │ └── kerutils/ │ │ ├── common/ │ │ │ └── common.h │ │ ├── device/ │ │ │ ├── common.h │ │ │ ├── device.cuh │ │ │ ├── sm100/ │ │ │ │ ├── gemm.cuh │ │ │ │ ├── helpers.cuh │ │ │ │ ├── intrinsics.cuh │ │ │ │ └── tma_cta_group2_nosplit.cuh │ │ │ ├── sm80/ │ │ │ │ ├── helpers.cuh │ │ │ │ └── intrinsics.cuh │ │ │ └── sm90/ │ │ │ ├── helpers.cuh │ │ │ └── intrinsics.cuh │ │ ├── host/ │ │ │ └── host.h │ │ ├── kerutils.cuh │ │ └── supplemental/ │ │ └── torch_tensors.h │ ├── params.h │ ├── sm100/ │ │ ├── decode/ │ │ │ ├── head128/ │ │ │ │ └── README.md │ │ │ └── head64/ │ │ │ ├── config.h │ │ │ ├── instantiations/ │ │ │ │ ├── model1.cu │ │ │ │ └── v32.cu │ │ │ ├── kernel.cuh │ │ │ └── kernel.h │ │ ├── helpers.h │ │ └── prefill/ │ │ ├── dense/ │ │ │ ├── collective/ │ │ │ │ ├── fmha_common.hpp │ │ │ │ ├── fmha_fusion.hpp │ │ │ │ ├── sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp │ │ │ │ ├── sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp │ │ │ │ ├── sm100_fmha_load_tma_warpspecialized.hpp │ │ │ │ ├── sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp │ │ │ │ └── sm100_fmha_mla_load_tma_warpspecialized.hpp │ │ │ ├── common/ │ │ │ │ ├── gather_tensor.hpp │ │ │ │ ├── helper.h │ │ │ │ ├── mask.cuh │ │ │ │ ├── pipeline_mla.hpp │ │ │ │ ├── pow_2.hpp │ │ │ │ └── utils.hpp │ │ │ ├── device/ │ │ │ │ ├── fmha.hpp │ │ │ │ └── fmha_device_bwd.hpp │ │ │ ├── fmha_cutlass_bwd_sm100.cu │ │ │ ├── fmha_cutlass_bwd_sm100.cuh │ │ │ ├── fmha_cutlass_fwd_sm100.cu │ │ │ ├── fmha_cutlass_fwd_sm100.cuh │ │ │ ├── interface.h │ │ │ └── kernel/ │ │ │ ├── 
fmha_causal_tile_scheduler.hpp │ │ │ ├── fmha_kernel_bwd_convert.hpp │ │ │ ├── fmha_kernel_bwd_sum_OdO.hpp │ │ │ ├── fmha_options.hpp │ │ │ ├── fmha_tile_scheduler.hpp │ │ │ ├── sm100_fmha_bwd_kernel_tma_warpspecialized.hpp │ │ │ ├── sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp │ │ │ └── sm100_fmha_fwd_kernel_tma_warpspecialized.hpp │ │ └── sparse/ │ │ ├── common_subroutine.h │ │ ├── fwd/ │ │ │ ├── head128/ │ │ │ │ ├── config.h │ │ │ │ ├── instantiations/ │ │ │ │ │ ├── phase1_k512.cu │ │ │ │ │ └── phase1_k576.cu │ │ │ │ ├── phase1.cuh │ │ │ │ └── phase1.h │ │ │ └── head64/ │ │ │ ├── config.h │ │ │ ├── instantiations/ │ │ │ │ ├── phase1_k512.cu │ │ │ │ └── phase1_k576.cu │ │ │ ├── phase1.cuh │ │ │ └── phase1.h │ │ └── fwd_for_small_topk/ │ │ └── head128/ │ │ ├── config.h │ │ ├── instantiations/ │ │ │ ├── phase1_decode_k512.cu │ │ │ └── phase1_prefill_k512.cu │ │ ├── phase1.cuh │ │ └── phase1.h │ ├── sm90/ │ │ ├── decode/ │ │ │ ├── dense/ │ │ │ │ ├── config.h │ │ │ │ ├── instantiations/ │ │ │ │ │ ├── bf16.cu │ │ │ │ │ └── fp16.cu │ │ │ │ ├── splitkv_mla.cuh │ │ │ │ ├── splitkv_mla.h │ │ │ │ └── traits.h │ │ │ └── sparse_fp8/ │ │ │ ├── components/ │ │ │ │ ├── config.h │ │ │ │ ├── dequant.h │ │ │ │ └── helpers.h │ │ │ ├── config.h │ │ │ ├── instantiations/ │ │ │ │ ├── model1_persistent_h128.cu │ │ │ │ ├── model1_persistent_h64.cu │ │ │ │ ├── v32_persistent_h128.cu │ │ │ │ └── v32_persistent_h64.cu │ │ │ ├── splitkv_mla.cuh │ │ │ └── splitkv_mla.h │ │ ├── helpers.h │ │ └── prefill/ │ │ └── sparse/ │ │ ├── config.h │ │ ├── fwd.cu │ │ ├── fwd.h │ │ ├── instantiations/ │ │ │ ├── phase1_k512.cu │ │ │ ├── phase1_k512_topklen.cu │ │ │ ├── phase1_k576.cu │ │ │ └── phase1_k576_topklen.cu │ │ ├── phase1.cuh │ │ └── phase1.h │ ├── smxx/ │ │ └── decode/ │ │ ├── combine/ │ │ │ ├── combine.cu │ │ │ └── combine.h │ │ └── get_decoding_sched_meta/ │ │ ├── get_decoding_sched_meta.cu │ │ └── get_decoding_sched_meta.h │ └── utils.h ├── docs/ │ ├── 20250422-new-kernel-deep-dive.md │ 
└── 20250929-hopper-fp8-sparse-deep-dive.md ├── flash_mla/ │ ├── __init__.py │ └── flash_mla_interface.py ├── setup.py └── tests/ ├── kernelkit/ │ ├── .gitignore │ ├── __init__.py │ ├── bench.py │ ├── compare.py │ ├── generate.py │ ├── precision.py │ └── utils.py ├── lib.py ├── quant.py ├── ref.py ├── test_flash_mla_dense_decoding.py ├── test_flash_mla_sparse_decoding.py ├── test_flash_mla_sparse_prefill.py └── test_fmha_sm100.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ build *.so *.egg-info/ __pycache__/ dist/ *perf.csv *.png /.vscode compile_commands.json .cache /dev /.clangd ================================================ FILE: .gitmodules ================================================ [submodule "csrc/cutlass"] path = csrc/cutlass url = https://github.com/NVIDIA/cutlass.git ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2025 DeepSeek Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # FlashMLA ## Introduction FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations: **Sparse Attention Kernels** *These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).* - Token-level sparse attention for the prefill stage - Token-level sparse attention for the decoding stage, with FP8 KV cache **Dense Attention Kernels** - Dense attention for the prefill stage - Dense attention for the decoding stage ## News - **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. We also release a deep-dive blog for our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md). - **2025.08.01 Kernels for MHA on SM100**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on SM100! - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). 
- **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 ## Performance #### Test & benchmark MLA decoding (Sparse & Dense): ```bash python tests/test_flash_mla_dense_decoding.py python tests/test_flash_mla_sparse_decoding.py ``` The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet). #### Test & benchmark MHA prefill (Dense): ```bash python tests/test_fmha_sm100.py ``` It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation on B200, as reported by NVIDIA. #### Test & benchmark MLA prefill (Sparse): ```bash python tests/test_flash_mla_sparse_prefill.py ``` It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. ## Requirements - SM90 / SM100 (See the support matrix below) - CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels) - PyTorch 2.0 and above Support matrix: | Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | | :---: | :---: | :---: | :---: | | Dense Decoding | SM90 | MQA | BF16 | | Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] | | Dense Prefill | SM100 | MHA | | | Sparse Prefill | SM90 & SM100 | MQA | | [1]: For more details on using FP8 KV cache, see documents below. [2]: Here "MLA Mode" refers to the mode used for MLA calculation. 
MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp). ## Installation ```bash git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla cd flash-mla git submodule update --init --recursive pip install -v . ``` ## Usage ### MLA Decoding To use the MLA decoding kernels, call get_mla_metadata once before the decoding loop to get the tile scheduler metadata. Then, call flash_mla_with_kvcache in each decoding step. For example: ```python from flash_mla import get_mla_metadata, flash_mla_with_kvcache tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk, ) for i in range(num_layers): ... o_i, lse_i = flash_mla_with_kvcache( q_i, kvcache_i, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, is_causal, is_fp8_kvcache, indices, ) ... ``` Where - `s_q` is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1. - `h_kv` is the number of key-value heads. - `h_q` is the number of query heads. **FP8 KV Cache:** If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16. In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as: - **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values. - **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on. - **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. 
This part is kept unquantized to preserve accuracy. See `tests/quant.py` for quantization and dequantization details. **Sparse Attention (`indices` tensor):** The `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens. - **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`. - **Format:** `indices[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th selected KV token for the j-th query token of the i-th sequence in the batch. Since the index of the page block is already encoded into `indices`, the kernel does not require the `block_table` parameter. - **Invalid entries:** Set invalid indices to `-1`. **Return Values:** The kernel returns `(out, lse)`, where: - `out` is the attention result. - `lse` is the log-sum-exp value of the attention scores for each query head. See `tests/test_flash_mla_sparse_decoding.py` (sparse) and `tests/test_flash_mla_dense_decoding.py` (dense) for complete examples. ### Sparse MLA Prefill For the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters: - `q`: Query tensor of shape `[s_q, h_q, d_qk]` - `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]` - `indices`: Indices tensor of shape `[s_q, h_kv, topk]` - `sm_scale`: A scalar softmax scale applied to the `q @ kv^T` logits **Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, concatenate the sequences along the token dimension and offset each sequence's `indices` by the start position of its KV tokens in the concatenated `kv` to simulate batch processing. **Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`. **Return Values and Equivalent PyTorch Code:** The kernel returns `(out, max_logits, lse)`.
This is equivalent to the following PyTorch operations: ```python Q: [s_q, h_q, d_qk], bfloat16 kv: [s_kv, h_kv, d_qk], bfloat16 indices: [s_q, h_kv, topk], int32 kv = kv.squeeze(1) # [s_kv, d_qk], h_kv must be 1 indices = indices.squeeze(1) # [s_q, topk] focused_kv = kv[indices] # For the i-th query token, the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk]. P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e) # [s_q, h_q, topk] max_logits = P.amax(dim=-1) # [s_q, h_q] lse = log2sumexp2(P, dim=-1, base=2) # [s_q, h_q]; "log2sumexp2" means that the exponentiation and logarithm are base-2 S = exp2(P - lse.unsqueeze(-1)) # [s_q, h_q, topk] out = S @ focused_kv # [s_q, h_q, d_qk] return (out, max_logits, lse) ``` See `tests/test_flash_mla_sparse_prefill.py` for a complete example. ### Dense MHA Prefill This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using: - `flash_attn_varlen_func` - `flash_attn_varlen_qkvpacked_func` - `flash_attn_varlen_kvpacked_func` The usage is similar to the `flash_attn` package. See `tests/test_fmha_sm100.py` for a complete example. ## Acknowledgement FlashMLA is inspired by the [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [CUTLASS](https://github.com/nvidia/cutlass) projects. ## Community Support ### MetaX For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA). ### Moore Threads For Moore Threads GPUs, visit the official website: [Moore Threads](https://www.mthreads.com/). The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). ### Hygon DCU For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/).
The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). ### Intellifusion For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com). The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). ### Iluvatar Corex For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com). The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) ### AMD Instinct For AMD Instinct GPUs, visit the official website: [AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html). The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py) ## Citation ```bibtex @misc{flashmla2025, title={FlashMLA: Efficient Multi-head Latent Attention Kernels}, author={Jiashi Li, Shengyu Liu}, year={2025}, publisher = {GitHub}, howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}}, } ``` ================================================ FILE: benchmark/bench_flash_mla.py ================================================ # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a import argparse import math import random import flashinfer import torch import triton import triton.language as tl # pip install flashinfer-python from flash_mla import flash_mla_with_kvcache, get_mla_metadata def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): query = query.float() key = key.float() value = value.float() key = key.repeat_interleave(h_q // h_kv, dim=0) value = value.repeat_interleave(h_q // h_kv, dim=0) attn_weight = query @ key.transpose(-2, -1) / 
math.sqrt(query.size(-1)) if is_causal: s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias lse = attn_weight.logsumexp(dim=-1) attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) return attn_weight @ value, lse @torch.inference_mode() def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): for i in range(b): blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") blocked_v = blocked_k[..., :dv] def ref_mla(): out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] O, LSE = scaled_dot_product_attention( q[i].transpose(0, 1), blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), h_q, h_kv, is_causal=causal, ) out[i] = O.transpose(0, 1) lse[i] = LSE return out, lse out_torch, lse_torch = ref_mla() t = triton.testing.do_bench(ref_mla) return out_torch, lse_torch, t @torch.inference_mode() def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): for i in range(b): blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") blocked_v = blocked_k[..., :dv] tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) def flash_mla(): return flash_mla_with_kvcache( q, blocked_k, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, causal=causal, ) out_flash, lse_flash = flash_mla() t = triton.testing.do_bench(flash_mla) return out_flash, lse_flash, t @torch.inference_mode() def 
run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): for i in range(b): blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() kv_indptr = [0] kv_indices = [] for i in range(b): seq_len = cache_seqlens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_table[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) for seq_len in cache_seqlens[1:]: kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) q_indptr = torch.arange(0, b + 1).int() * s_q kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3" ) mla_wrapper.plan( q_indptr, kv_indptr, kv_indices, cache_seqlens, h_q, dv, d-dv, block_size, causal, 1 / math.sqrt(d), q.dtype, blocked_k.dtype, ) def flash_infer(): output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True) return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) out_flash, lse_flash = flash_infer() t = triton.testing.do_bench(flash_infer) return out_flash, lse_flash, t @triton.jit def _mla_attn_kernel( Q_nope, Q_pe, Kv_c_cache, K_pe_cache, Req_to_tokens, B_seq_len, O, sm_scale, stride_q_nope_bs, stride_q_nope_h, stride_q_pe_bs, stride_q_pe_h, stride_kv_c_bs, stride_k_pe_bs, stride_req_to_tokens_bs, stride_o_b, stride_o_h, stride_o_s, BLOCK_H: tl.constexpr, BLOCK_N: tl.constexpr, NUM_KV_SPLITS: tl.constexpr, PAGE_SIZE: tl.constexpr, HEAD_DIM_CKV: tl.constexpr, HEAD_DIM_KPE: tl.constexpr, ): 
cur_batch = tl.program_id(1) cur_head_id = tl.program_id(0) split_kv_id = tl.program_id(2) cur_batch_seq_len = tl.load(B_seq_len + cur_batch) offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] q_pe = tl.load(Q_pe + offs_q_pe) e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) qk *= sm_scale qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) v_c = tl.trans(k_c) n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) acc *= re_scale[:, None] acc += tl.dot(p.to(v_c.dtype), v_c) e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id 
* stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) def _mla_attn( q_nope, q_pe, kv_c_cache, k_pe_cache, attn_logits, req_to_tokens, b_seq_len, num_kv_splits, sm_scale, page_size, ): batch_size, head_num = q_nope.shape[0], q_nope.shape[1] head_dim_ckv = q_nope.shape[-1] head_dim_kpe = q_pe.shape[-1] BLOCK_H = 16 BLOCK_N = 64 grid = ( triton.cdiv(head_num, BLOCK_H), batch_size, num_kv_splits, ) _mla_attn_kernel[grid]( q_nope, q_pe, kv_c_cache, k_pe_cache, req_to_tokens, b_seq_len, attn_logits, sm_scale, # stride q_nope.stride(0), q_nope.stride(1), q_pe.stride(0), q_pe.stride(1), kv_c_cache.stride(-2), k_pe_cache.stride(-2), req_to_tokens.stride(0), attn_logits.stride(0), attn_logits.stride(1), attn_logits.stride(2), BLOCK_H=BLOCK_H, BLOCK_N=BLOCK_N, NUM_KV_SPLITS=num_kv_splits, PAGE_SIZE=page_size, HEAD_DIM_CKV=head_dim_ckv, HEAD_DIM_KPE=head_dim_kpe, ) @triton.jit def _mla_softmax_reducev_kernel( Logits, B_seq_len, O, stride_l_b, stride_l_h, stride_l_s, stride_o_b, stride_o_h, NUM_KV_SPLITS: tl.constexpr, HEAD_DIM_CKV: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) cur_batch_seq_len = tl.load(B_seq_len + cur_batch) offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) e_sum = 0.0 e_max = -float("inf") acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV for split_kv_id in range(0, NUM_KV_SPLITS): kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) if split_kv_end > split_kv_start: logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) n_e_max = 
tl.maximum(logits_1, e_max) old_scale = tl.exp(e_max - n_e_max) acc *= old_scale exp_logic = tl.exp(logits_1 - n_e_max) acc += exp_logic * logits e_sum = e_sum * old_scale + exp_logic e_max = n_e_max tl.store( O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, acc / e_sum, ) def _mla_softmax_reducev( logits, o, b_seq_len, num_kv_splits, ): batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] grid = (batch_size, head_num) _mla_softmax_reducev_kernel[grid]( logits, b_seq_len, o, logits.stride(0), logits.stride(1), logits.stride(2), o.stride(0), o.stride(1), NUM_KV_SPLITS=num_kv_splits, HEAD_DIM_CKV=head_dim_ckv, num_warps=4, num_stages=2, ) def mla_decode_triton( q_nope, q_pe, kv_c_cache, k_pe_cache, o, req_to_tokens, b_seq_len, attn_logits, num_kv_splits, sm_scale, page_size, ): assert num_kv_splits == attn_logits.shape[2] _mla_attn( q_nope, q_pe, kv_c_cache, k_pe_cache, attn_logits, req_to_tokens, b_seq_len, num_kv_splits, sm_scale, page_size, ) _mla_softmax_reducev( attn_logits, o, b_seq_len, num_kv_splits, ) @torch.inference_mode() def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): for i in range(b): blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size) return o.view([b, s_q, h_q, dv]) out_flash = 
flash_mla_triton() t = triton.testing.do_bench(flash_mla_triton) return out_flash, None, t FUNC_TABLE = { "torch": run_torch_mla, "flash_mla": run_flash_mla, "flash_infer": run_flash_infer, "flash_mla_triton": run_flash_mla_triton, } def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") device = torch.device("cuda:0") torch.set_default_dtype(dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) assert baseline in FUNC_TABLE assert target in FUNC_TABLE baseline_func = FUNC_TABLE[baseline] target_func = FUNC_TABLE[target] total_seqlens = cache_seqlens.sum().item() mean_seqlens = cache_seqlens.float().mean().int().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") q = torch.randn(b, s_q, h_q, d) block_size = 64 block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]: # flash_infer has a different lse return value # flash_mla_triton doesn't return lse torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 bytes = (total_seqlens * h_kv * d + b 
* s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s") print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) assert target in FUNC_TABLE target_func = FUNC_TABLE[target] total_seqlens = cache_seqlens.sum().item() mean_seqlens = cache_seqlens.float().mean().int().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") q = torch.randn(b, s_q, h_q, d) block_size = 64 block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") return bytes / 10 ** 6 / perf_b available_targets = [ "torch", "flash_mla", "flash_infer", "flash_mla_triton", ] shape_configs = [ {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, 
"dv": 512, "causal": True, "dtype": torch.bfloat16} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128] ] def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--baseline", type=str, default="torch") parser.add_argument("--target", type=str, default="flash_mla") parser.add_argument("--all", action="store_true") parser.add_argument("--one", action="store_true") parser.add_argument("--compare", action="store_true") args = parser.parse_args() return args if __name__ == "__main__": args = get_args() benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target with open(f"{benchmark_type}_perf.csv", "w") as fout: fout.write("name,batch,seqlen,head,bw\n") for shape in shape_configs: if args.all: for target in available_targets: perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') elif args.compare: perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n') fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n') elif args.one: perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') ================================================ FILE: benchmark/visualize.py 
================================================ import argparse import matplotlib.pyplot as plt import pandas as pd def parse_args(): parser = argparse.ArgumentParser(description='Visualize benchmark results') parser.add_argument('--file', type=str, default='all_perf.csv', help='Path to the CSV file with benchmark results (default: all_perf.csv)') return parser.parse_args() args = parse_args() file_path = args.file df = pd.read_csv(file_path) names = df['name'].unique() for name in names: subset = df[df['name'] == name] plt.plot(subset['seqlen'], subset['bw'], label=name) plt.title('bandwidth') plt.xlabel('seqlen') plt.ylabel('bw (GB/s)') plt.legend() plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png') ================================================ FILE: csrc/api/api.cpp ================================================ /* NOTE(review): text inside angle brackets (#include targets, template argument lists) appears to have been stripped when this dump was extracted, so the C++ below will not compile as shown. */ #include #include "sparse_fwd.h" #include "sparse_decode.h" #include "dense_decode.h" #include "dense_fwd.h" /* Registers five entry points on the extension module: sparse/dense decode forward, sparse/dense prefill forward, and dense prefill backward. */ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashMLA"; m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("dense_decode_fwd", &dense_attn_decode_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun); m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun); } ================================================ FILE: csrc/api/common.h ================================================ #pragma once #include #include #include #include #include #include /* log2(e); presumably used to rescale logits for base-2 exponentials -- confirm at call sites. */ static constexpr float LOG_2_E = 1.44269504f; // Instantiation for tensor.data_ptr() /* Returns the tensor's raw pointer reinterpreted as cutlass::bfloat16_t; no dtype check is performed here. */ template<> inline cutlass::bfloat16_t* at::TensorBase::data_ptr() const { return reinterpret_cast(this->data_ptr()); } // A struct that holds the architecture information of the current GPU.
struct Arch { int major; int minor; int num_sms; cudaDeviceProp* device_prop; /* Reads the properties of the current CUDA device once, at construction. */ Arch() { device_prop = at::cuda::getCurrentDeviceProperties(); major = device_prop->major; minor = device_prop->minor; num_sms = device_prop->multiProcessorCount; } /* Exactly compute capability 9.0. */ bool is_sm90a() const { return major == 9 && minor == 0; } /* Any compute capability 10.x; the minor version is not checked. */ bool is_sm100f() const { return major == 10; } }; // Convert int64_t stride to int32_t, with overflow check. /* NOTE(review): only the upper bound is checked; a stride below the int32 minimum would be truncated silently. */ inline int int64_stride_to_int(int64_t orig_stride) { if (orig_stride > std::numeric_limits::max()) { TORCH_CHECK(false, "[FlashMLA] Stride exceeds int32 limit: ", orig_stride); } return static_cast(orig_stride); } /* Each DISPATCH_* macro expands to an immediately-invoked lambda that binds CONSTEXPR_NAME as a compile-time constant equal to the matching runtime value and returns the result of calling __VA_ARGS__ (a callable). All except DISPATCH_BOOLEAN_FLAG fail via TORCH_CHECK on unsupported values. */ #define DISPATCH_NUM_HEADS(NUM_HEADS, CONSTEXPR_NAME, ...) \ [&] () { \ if (NUM_HEADS == 128) { \ static constexpr int CONSTEXPR_NAME = 128; \ return __VA_ARGS__(); \ } else if (NUM_HEADS == 64) { \ static constexpr int CONSTEXPR_NAME = 64; \ return __VA_ARGS__(); \ } else { \ TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \ } \ } (); #define DISPATCH_HEAD_DIM(HEAD_DIM, CONSTEXPR_NAME, ...) \ [&] () { \ if (HEAD_DIM == 576) { \ static constexpr int CONSTEXPR_NAME = 576; \ return __VA_ARGS__(); \ } else if (HEAD_DIM == 512) { \ static constexpr int CONSTEXPR_NAME = 512; \ return __VA_ARGS__(); \ } else { \ TORCH_CHECK(false, "Unsupported head_dim_qk: ", HEAD_DIM); \ } \ } (); #define DISPATCH_BOOLEAN_FLAG(FLAG, CONSTEXPR_NAME, ...) \ [&] () { \ if (FLAG) { \ static constexpr bool CONSTEXPR_NAME = true; \ return __VA_ARGS__(); \ } else { \ static constexpr bool CONSTEXPR_NAME = false; \ return __VA_ARGS__(); \ } \ } (); #define DISPATCH_MODEL_TYPE(MODEL_TYPE, CONSTEXPR_NAME, ...)
\ [&] () { \ if (MODEL_TYPE == ModelType::V32) { \ static constexpr ModelType CONSTEXPR_NAME = ModelType::V32; \ return __VA_ARGS__(); \ } else if (MODEL_TYPE == ModelType::MODEL1) { \ static constexpr ModelType CONSTEXPR_NAME = ModelType::MODEL1; \ return __VA_ARGS__(); \ } else { \ TORCH_CHECK(false, "Unsupported model type: ", (int)MODEL_TYPE); \ } \ } (); // The following code is adapted from https://ykiko.me/en/articles/680412313/, which converts enum values to string names. /* Extracts the name of a compile-time enum value from the compiler's function-signature string (__PRETTY_FUNCTION__ on GCC/Clang, __FUNCSIG__ on MSVC), then drops any qualifier prefix through a "::" (first occurrence on GCC/Clang, last on MSVC). */ template constexpr auto get_static_enum_name(){ std::string_view name; #if __GNUC__ || __clang__ name = __PRETTY_FUNCTION__; std::size_t start = name.find('=') + 2; std::size_t end = name.size() - 1; name = std::string_view{ name.data() + start, end - start }; start = name.find("::"); #elif _MSC_VER name = __FUNCSIG__; std::size_t start = name.find('<') + 1; std::size_t end = name.rfind(">("); name = std::string_view{ name.data() + start, end - start }; start = name.rfind("::"); #endif return start == std::string_view::npos ? name : std::string_view { name.data() + start + 2, name.size() - start - 2 }; } /* Appears to probe enum values N, N+1, ... until a value's printed name contains ")" (presumably the compiler's cast-style spelling for a value with no enumerator) and returns that first value. NOTE(review): probing stops at the first gap, so enums with non-contiguous values are only partially covered. */ template static constexpr std::size_t get_enum_max(){ constexpr T value = static_cast(N); if constexpr (get_static_enum_name().find(")") == std::string_view::npos) return get_enum_max(); else return N; } /* Maps a runtime enum value to its name through a table built at compile time from get_static_enum_name. NOTE(review): the table is indexed with the raw value without a bounds check; values at or beyond the probed maximum read out of range. */ template requires std::is_enum_v static constexpr std::string get_dynamic_enum_name(T value){ constexpr std::size_t num = get_enum_max(); constexpr auto names = [](std::index_sequence){ return std::array{ get_static_enum_name(Is)>()... }; }(std::make_index_sequence{}); return (std::string)names[static_cast(value)]; } // A shortcut macro to declare supported features in an implementation class. #define DECLARE_SUPPORTED_FEATURES(...) \ protected: \ static constexpr FeatureT features[] = { __VA_ARGS__ }; \ constexpr inline std::span get_supported_features() const override { \ return features; \ } /* ImplBase - The base class for every implementation.
Every implementation should inherit from this class and implement the pure virtual functions, including: - `run_`: The function that runs the implementation. - `get_supported_features`: The function that returns the supported features of the implementation. You may use `DECLARE_SUPPORTED_FEATURES` to declare the supported features in a concise way. The dispatcher will invoke `ImplBase::run()`, which checks if all required features are supported by the implementation, and then calls `run_`. */ template< typename RunArgT_, typename FeatureT_ > class ImplBase { protected: using RunArgT = RunArgT_; using FeatureT = FeatureT_; virtual inline void run_(const RunArgT ¶ms, const std::vector &required_features) = 0; constexpr virtual inline std::span get_supported_features() const = 0; virtual ~ImplBase() = default; public: inline bool check_if_all_features_are_supported(const std::vector &required_features) { for (const auto &required_feature : required_features) { bool is_supported = false; for (const auto &supported_feature : get_supported_features()) { if (required_feature == supported_feature) { is_supported = true; break; } } if (!is_supported) { return false; } } return true; } inline void check_if_all_features_are_supported_and_abort(const std::vector &required_features) { if (!check_if_all_features_are_supported(required_features)) { fprintf(stderr, "[FlashMLA] Error: The chosen implementation does not support all required features.\n"); fprintf(stderr, "Required features:\n"); for (const auto &f : required_features) { fprintf(stderr, " - %3d: %s\n", static_cast(f), get_dynamic_enum_name(f).c_str()); } fprintf(stderr, "\n"); fprintf(stderr, "Supported features:\n"); for (const auto &supported_feature : get_supported_features()) { fprintf(stderr, " - %3d: %s\n", static_cast(supported_feature), get_dynamic_enum_name(supported_feature).c_str()); } fprintf(stderr, "\n"); fprintf(stderr, "Features that are required but not supported:\n"); for (const auto 
&required_feature : required_features) { bool is_supported = false; for (const auto &supported_feature : get_supported_features()) { if (required_feature == supported_feature) { is_supported = true; break; } } if (!is_supported) { fprintf(stderr, " - %3d: %s\n", static_cast(required_feature), get_dynamic_enum_name(required_feature).c_str()); } } fprintf(stderr, "\n"); Arch cur_gpu_arch = Arch(); fprintf(stderr, "Current GPU: %s, SM %d.%d with %d SMs\n", cur_gpu_arch.device_prop->name, cur_gpu_arch.major, cur_gpu_arch.minor, cur_gpu_arch.num_sms); fprintf(stderr, "This means that the dispatcher has chosen an implementation that does not support all required features. Maybe there is a bug in the dispatcher, or you have requested an invalid combination of features.\n"); TORCH_CHECK(false, "The chosen implementation does not support all required features. See message above for details."); } } inline void run(const RunArgT ¶ms, const std::vector &required_features) { check_if_all_features_are_supported_and_abort(required_features); run_(params, required_features); } }; ================================================ FILE: csrc/api/dense_decode.h ================================================ #pragma once #include #include #include "common.h" #include "params.h" #include "sm90/decode/dense/splitkv_mla.h" #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/combine/combine.h" static std::tuple, std::optional> dense_attn_decode_interface( at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq const float softmax_scale, bool is_causal, std::optional &tile_scheduler_metadata, // num_sm_parts x 
(DecodingSchedMetaSize/4) std::optional &num_splits // batch_size + 1 ) { // Check arch Arch arch = Arch(); if (!arch.is_sm90a()) { TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture"); } // Check data types auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); // Check device KU_CHECK_DEVICE(q); KU_CHECK_DEVICE(kcache); KU_CHECK_DEVICE(seqlens_k); KU_CHECK_DEVICE(block_table); KU_CHECK_DEVICE(tile_scheduler_metadata); KU_CHECK_DEVICE(num_splits); // Check layout TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension"); KU_CHECK_CONTIGUOUS(seqlens_k); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); KU_CHECK_CONTIGUOUS(num_splits); const auto sizes = q.sizes(); const int batch_size = sizes[0]; const int seqlen_q_ori = sizes[1]; const int num_heads_q = sizes[2]; const int head_size_k = sizes[3]; TORCH_CHECK(head_size_k == 576 || head_size_k == 512, "Only head_size_k == 576 or 512 is supported"); TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); const int page_block_size = kcache.size(1); const int num_heads_k = kcache.size(2); TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64"); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q_ori == 1) { is_causal = false; } const int num_q_heads_per_hk 
= num_heads_q / num_heads_k; const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; const int num_heads = num_heads_k; q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 64), 1); KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); KU_CHECK_SHAPE(seqlens_k, batch_size); KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int)); KU_CHECK_SHAPE(num_splits, batch_size+1); at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts); at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); KU_CHECK_CONTIGUOUS(out); KU_CHECK_CONTIGUOUS(lse); if (!tile_scheduler_metadata.has_value()) { tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32)); num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32)); KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); KU_CHECK_CONTIGUOUS(num_splits); GetDecodeSchedMetaParams get_sched_meta_params = { batch_size, seqlen_q_ori, 64, 5, -1, -1, nullptr, nullptr, seqlens_k.data_ptr(), (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(), num_splits->data_ptr(), num_sm_parts, at::cuda::getCurrentCUDAStream().stream() }; smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); } else { KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); KU_CHECK_DTYPE(num_splits, torch::kInt32); KU_CHECK_DEVICE(tile_scheduler_metadata); KU_CHECK_DEVICE(num_splits); KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); 
KU_CHECK_CONTIGUOUS(num_splits); KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int)); KU_CHECK_SHAPE(num_splits, batch_size+1); } // Set the sizes DenseAttnDecodeParams params; params.b = batch_size; params.s_q = seqlen_q_ori; params.q_seq_per_hk = q_seq_per_hk; params.seqlens_k_ptr = seqlens_k.data_ptr(); params.h_q = num_heads_q; params.h_k = num_heads_k; params.num_blocks = num_blocks; params.q_head_per_hk = num_q_heads_per_hk; params.is_causal = is_causal; params.d = head_size_k; params.d_v = head_size_v; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Set the pointers and strides. params.q_ptr = q.data_ptr(); params.k_ptr = kcache.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = lse.data_ptr(); // All stride are in elements, not bytes. params.q_batch_stride = q.stride(0); params.k_batch_stride = kcache.stride(0); params.o_batch_stride = out.stride(0); params.q_row_stride = q.stride(1); params.k_row_stride = kcache.stride(1); params.o_row_stride = out.stride(2); params.q_head_stride = q.stride(2); params.k_head_stride = kcache.stride(2); params.o_head_stride = out.stride(1); params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size; params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(); params.num_sm_parts = num_sm_parts; params.num_splits_ptr = num_splits->data_ptr(); const int total_num_splits = batch_size + params.num_sm_parts; at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); KU_CHECK_CONTIGUOUS(lse_accum); KU_CHECK_CONTIGUOUS(out_accum); params.total_num_splits = total_num_splits; params.softmax_lseaccum_ptr = lse_accum.data_ptr(); params.oaccum_ptr = 
out_accum.data_ptr(); params.stream = at::cuda::getCurrentCUDAStream().stream(); if (q_dtype == torch::kBFloat16) { sm90::run_flash_splitkv_mla_kernel(params); } else if (q_dtype == torch::kHalf) { #ifdef FLASH_MLA_DISABLE_FP16 TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA."); #else sm90::run_flash_splitkv_mla_kernel(params); #endif } else { TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); } CombineParams combine_params = { batch_size, seqlen_q_ori, num_heads_q, head_size_v, params.softmax_lse_ptr, params.o_ptr, num_heads*q_seq_per_hk, num_heads_q, num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v, params.softmax_lseaccum_ptr, params.oaccum_ptr, num_heads*q_seq_per_hk, num_heads_q, num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v, params.tile_scheduler_metadata_ptr, params.num_splits_ptr, params.num_sm_parts, nullptr, at::cuda::getCurrentCUDAStream().stream() }; if (q_dtype == torch::kBFloat16) { smxx::decode::run_flash_mla_combine_kernel(combine_params); } else if (q_dtype == torch::kHalf) { #ifndef FLASH_MLA_DISABLE_FP16 smxx::decode::run_flash_mla_combine_kernel(combine_params); #endif } else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2) .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) .reshape({batch_size, num_heads_q, seqlen_q_ori}); return {out, lse, tile_scheduler_metadata, num_splits}; } ================================================ FILE: csrc/api/dense_fwd.h ================================================ #pragma once #include "common.h" #include "sm100/prefill/dense/interface.h" ================================================ FILE: csrc/api/sparse_decode.h 
================================================ #pragma once #include "common.h" #include "params.h" #include "sm90/decode/sparse_fp8/splitkv_mla.h" #include "sm100/decode/head64/kernel.h" #include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h" #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/combine/combine.h" // Feature set of sparse decoding kernels enum class DecodeFeatures : int { HEAD_64, HEAD_128, HEAD_DIM_576, HEAD_DIM_512, V32_KVCACHE_FORMAT, MODEL1_KVCACHE_FORMAT, ATTN_SINK, TOPK_LENGTH, EXTRA_KVCACHE, EXTRA_TOPK_LENGTH }; struct DecodeImplMeta { int num_sm_parts; int fixed_overhead_num_blocks; int block_size_topk; }; class DecodeImplBase : public ImplBase< SparseAttnDecodeParams, DecodeFeatures > { public: virtual DecodeImplMeta get_meta(int h_q, int s_q) = 0; }; class Decode_Sm90_Impl : public DecodeImplBase { DECLARE_SUPPORTED_FEATURES( DecodeFeatures::HEAD_64, DecodeFeatures::HEAD_128, DecodeFeatures::HEAD_DIM_512, DecodeFeatures::HEAD_DIM_576, DecodeFeatures::V32_KVCACHE_FORMAT, DecodeFeatures::MODEL1_KVCACHE_FORMAT, DecodeFeatures::ATTN_SINK, DecodeFeatures::TOPK_LENGTH, DecodeFeatures::EXTRA_KVCACHE, DecodeFeatures::EXTRA_TOPK_LENGTH ) public: DecodeImplMeta get_meta(int h_q, int s_q) override { Arch arch = Arch(); return { std::max(arch.num_sms / s_q / (h_q/64), 1), 5, 64 }; } protected: void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() { DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() { sm90::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel(params); }); }); } }; class Decode_Sm100_Head64_Impl : public DecodeImplBase { DECLARE_SUPPORTED_FEATURES( DecodeFeatures::HEAD_64, DecodeFeatures::HEAD_DIM_512, DecodeFeatures::HEAD_DIM_576, DecodeFeatures::V32_KVCACHE_FORMAT, DecodeFeatures::MODEL1_KVCACHE_FORMAT, DecodeFeatures::ATTN_SINK, DecodeFeatures::TOPK_LENGTH, 
DecodeFeatures::EXTRA_KVCACHE, DecodeFeatures::EXTRA_TOPK_LENGTH ) public: DecodeImplMeta get_meta(int h_q, int s_q) override { Arch arch = Arch(); return { std::max(arch.num_sms / s_q, 1), 5, 64 }; } protected: void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() { sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel(params); }); } }; // An implementation that calls the head64 kernel twice to process head128 // Necessary for running V3.2 shape (i.e. h = 128, d_qk = 576) on SM100f class Decode_Sm100_Head64x2_Impl : public DecodeImplBase { DECLARE_SUPPORTED_FEATURES( DecodeFeatures::HEAD_128, DecodeFeatures::HEAD_DIM_512, DecodeFeatures::HEAD_DIM_576, DecodeFeatures::V32_KVCACHE_FORMAT, DecodeFeatures::MODEL1_KVCACHE_FORMAT, DecodeFeatures::ATTN_SINK, DecodeFeatures::TOPK_LENGTH, DecodeFeatures::EXTRA_KVCACHE, DecodeFeatures::EXTRA_TOPK_LENGTH ) public: DecodeImplMeta get_meta(int h_q, int s_q) override { Arch arch = Arch(); return { std::max(arch.num_sms / s_q, 1), 5, 64 }; } protected: void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() { for (int start_head_idx = 0; start_head_idx < 128; start_head_idx += 64) { SparseAttnDecodeParams cur_params = params; cur_params.q += start_head_idx * params.stride_q_h_q; if (cur_params.attn_sink) { cur_params.attn_sink += start_head_idx; } cur_params.lse += start_head_idx; cur_params.out += start_head_idx * params.stride_o_h_q; cur_params.lse_accum += start_head_idx; cur_params.o_accum += start_head_idx * params.stride_o_accum_h_q; cur_params.h_q = 64; sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel(cur_params); } }); } }; class Decode_Sm100_Head128_Impl : public DecodeImplBase { DECLARE_SUPPORTED_FEATURES( DecodeFeatures::HEAD_128, DecodeFeatures::HEAD_DIM_512, DecodeFeatures::MODEL1_KVCACHE_FORMAT, 
DecodeFeatures::ATTN_SINK, DecodeFeatures::TOPK_LENGTH, DecodeFeatures::EXTRA_KVCACHE, DecodeFeatures::EXTRA_TOPK_LENGTH ) public: DecodeImplMeta get_meta(int h_q, int s_q) override { Arch arch = Arch(); return { std::max(arch.num_sms / s_q / 2, 1), 3, 64 }; } protected: void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel(params); } }; static std::tuple, std::optional> sparse_attn_decode_interface( const at::Tensor &q, // [b, s_q, h_q, d_qk] const at::Tensor &kv, // [num_blocks, page_block_size, h_k, d_qk] const at::Tensor &indices, // [b, s_q, topk] const std::optional &topk_length, // [b, s_q] const std::optional &attn_sink, // [h_q] std::optional &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4) std::optional &num_splits, // batch_size + 1 const std::optional &extra_kv, const std::optional &extra_indices, const std::optional &extra_topk_length, int d_v, float sm_scale ) { using bf16 = cutlass::bfloat16_t; // Check the architecture Arch arch = Arch(); KU_CHECK_NDIM(q, 4); KU_CHECK_NDIM(kv, 4); KU_CHECK_NDIM(indices, 3); int b = q.size(0); int s_q = q.size(1); int h_q = q.size(2); int d_qk = q.size(3); int num_blocks = kv.size(0); int page_block_size = kv.size(1); int h_kv = kv.size(2); int topk = indices.size(2); bool have_topk_length = topk_length.has_value(); bool have_extra_kcache = extra_kv.has_value(); bool have_extra_topk_length = extra_topk_length.has_value(); bool have_attn_sink = attn_sink.has_value(); int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0; if (have_extra_kcache) { extra_num_blocks = extra_kv->size(0); extra_page_block_size = extra_kv->size(1); } if (extra_indices.has_value()) { extra_topk = extra_indices->size(-1); } // metadata sanity check TORCH_CHECK(b > 0); TORCH_CHECK(s_q > 0); TORCH_CHECK(h_q > 0); TORCH_CHECK(h_kv == 1, "Currently only MQA (i.e. 
h_kv == 1) is supported for sparse decoding"); TORCH_CHECK(d_qk == 576 || d_qk == 512, "Only head_size_k == 576 or 512 is supported for sparse decoding"); TORCH_CHECK(d_v == 512, "Only head_size_v == 512 is supported for sparse decoding"); TORCH_CHECK(topk > 0); if (have_extra_kcache) { TORCH_CHECK(extra_indices.has_value(), "extra_indices_in_kvcache must be provided when extra_kcache is provided for sparse attention"); } else { TORCH_CHECK(!extra_indices.has_value(), "extra_indices_in_kvcache must not be provided when extra_k_cache is not provided"); TORCH_CHECK(!extra_topk_length.has_value(), "extra_topk_length must not be provided when extra_k_cache is not provided"); } // Check device KU_CHECK_DEVICE(q); KU_CHECK_DEVICE(kv); KU_CHECK_DEVICE(indices); KU_CHECK_DEVICE(topk_length); KU_CHECK_DEVICE(attn_sink); KU_CHECK_DEVICE(tile_scheduler_metadata); KU_CHECK_DEVICE(num_splits); KU_CHECK_DEVICE(extra_kv); KU_CHECK_DEVICE(extra_indices); KU_CHECK_DEVICE(extra_topk_length); // Check data type KU_CHECK_DTYPE(q, torch::kBFloat16); TORCH_CHECK(kv.dtype() == torch::kFloat8_e4m3fn || kv.dtype() == torch::kInt8 || kv.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn, int8 or uint8"); if (extra_kv.has_value()) { TORCH_CHECK(extra_kv->dtype() == torch::kFloat8_e4m3fn || extra_kv->dtype() == torch::kInt8 || extra_kv->dtype() == torch::kUInt8, "extra k cache must have dtype fp8_e4m3fn, int8 or uint8"); } KU_CHECK_DTYPE(indices, torch::kInt32); KU_CHECK_DTYPE(topk_length, torch::kInt32); KU_CHECK_DTYPE(attn_sink, torch::kFloat32); KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); KU_CHECK_DTYPE(num_splits, torch::kInt32); KU_CHECK_DTYPE(extra_indices, torch::kInt32); KU_CHECK_DTYPE(extra_topk_length, torch::kInt32); // Check layout KU_CHECK_LAST_DIM_CONTIGUOUS(q); KU_CHECK_LAST_DIM_CONTIGUOUS(kv); KU_CHECK_LAST_DIM_CONTIGUOUS(indices); KU_CHECK_CONTIGUOUS(topk_length); KU_CHECK_CONTIGUOUS(attn_sink); KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); 
KU_CHECK_CONTIGUOUS(num_splits); KU_CHECK_LAST_DIM_CONTIGUOUS(extra_kv); KU_CHECK_LAST_DIM_CONTIGUOUS(extra_indices); KU_CHECK_CONTIGUOUS(extra_topk_length); // Check shape KU_CHECK_SHAPE(q, b, s_q, h_q, d_qk); { int bytes_per_token; if (d_qk == 576 && d_v == 512) { // V3.2 style bytes_per_token = 512 + 64*2 + (512/128)*4; } else if (d_qk == 512 && d_v == 512) { // MODEL1 style bytes_per_token = 448 + 64*2 + (448/64)*1 + 1; } else { TORCH_CHECK(false, "Unsupported head sizes for is_fp8_kvcache == True"); } KU_CHECK_SHAPE(kv, num_blocks, page_block_size, h_kv, bytes_per_token); KU_CHECK_SHAPE(extra_kv, extra_num_blocks, extra_page_block_size, h_kv, bytes_per_token); TORCH_CHECK(kv.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for kv cache"); if (extra_kv.has_value()) { TORCH_CHECK(extra_kv->stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for extra kv cache"); } } KU_CHECK_SHAPE(indices, b, s_q, topk); KU_CHECK_SHAPE(topk_length, b); KU_CHECK_SHAPE(attn_sink, h_q); KU_CHECK_SHAPE(extra_indices, b, s_q, extra_topk); KU_CHECK_SHAPE(extra_topk_length, b); at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); at::Tensor out = torch::empty({b, s_q, h_q, d_v}, opts); at::Tensor lse = torch::empty({b, s_q, h_q}, opts.dtype(at::kFloat)); ModelType model_type; if (d_qk == 576) { model_type = ModelType::V32; } else if (d_qk == 512) { model_type = ModelType::MODEL1; } else { TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); } std::vector features; if (h_q == 64) { features.push_back(DecodeFeatures::HEAD_64); } else if (h_q == 128) { features.push_back(DecodeFeatures::HEAD_128); } else { TORCH_CHECK(false, "Unsupported h_q: ", h_q); } if (d_qk == 576) { features.push_back(DecodeFeatures::HEAD_DIM_576); } else if (d_qk == 512) { features.push_back(DecodeFeatures::HEAD_DIM_512); } else { TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); } if (model_type == 
ModelType::V32) { features.push_back(DecodeFeatures::V32_KVCACHE_FORMAT); } else if (model_type == ModelType::MODEL1) { features.push_back(DecodeFeatures::MODEL1_KVCACHE_FORMAT); } else { TORCH_CHECK(false, "Unsupported model type: ", (int)model_type); } if (have_attn_sink) { features.push_back(DecodeFeatures::ATTN_SINK); } if (have_topk_length) { features.push_back(DecodeFeatures::TOPK_LENGTH); } if (have_extra_kcache) { features.push_back(DecodeFeatures::EXTRA_KVCACHE); } if (have_extra_topk_length) { features.push_back(DecodeFeatures::EXTRA_TOPK_LENGTH); } DecodeImplBase* impl; if (arch.is_sm100f()) { if (h_q == 64) { impl = new Decode_Sm100_Head64_Impl(); } else if (h_q == 128) { if (d_qk == 576) { impl = new Decode_Sm100_Head64x2_Impl(); } else if (d_qk == 512) { impl = new Decode_Sm100_Head128_Impl(); } else { TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); } } else { TORCH_CHECK(false, "Unsupported h_q: ", h_q); } } else if (arch.is_sm90a()) { impl = new Decode_Sm90_Impl(); } else { TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd"); } DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q); SparseAttnDecodeParams params = { b, s_q, h_q, h_kv, d_qk, d_v, sm_scale, sm_scale * LOG_2_E, num_blocks, page_block_size, topk, model_type, (bf16*)q.data_ptr(), (bf16*)kv.data_ptr(), (int*)indices.data_ptr(), ku::get_optional_tensor_ptr(topk_length), ku::get_optional_tensor_ptr(attn_sink), (float*)lse.data_ptr(), (bf16*)out.data_ptr(), extra_num_blocks, extra_page_block_size, extra_topk, ku::get_optional_tensor_ptr(extra_kv), ku::get_optional_tensor_ptr(extra_indices), ku::get_optional_tensor_ptr(extra_topk_length), int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), int64_stride_to_int(q.stride(2)), int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)), int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)), int64_stride_to_int(lse.stride(0)), int64_stride_to_int(lse.stride(1)), 
int64_stride_to_int(out.stride(0)), int64_stride_to_int(out.stride(1)), int64_stride_to_int(out.stride(2)), have_extra_kcache ? int64_stride_to_int(extra_kv->stride(0)) : 0, have_extra_kcache ? int64_stride_to_int(extra_kv->stride(1)) : 0, have_extra_kcache ? int64_stride_to_int(extra_indices->stride(0)) : 0, have_extra_kcache ? int64_stride_to_int(extra_indices->stride(1)) : 0, at::cuda::getCurrentCUDAStream().stream() }; // Get MLA metadata if necessary at::Tensor o_accum, lse_accum; if (!tile_scheduler_metadata.has_value()) { tile_scheduler_metadata = torch::empty({impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32)); num_splits = torch::empty({b+1}, opts.dtype(torch::kInt32)); KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); KU_CHECK_CONTIGUOUS(num_splits); GetDecodeSchedMetaParams get_sched_meta_params = { b, s_q, impl_meta.block_size_topk, impl_meta.fixed_overhead_num_blocks, topk, extra_topk, ku::get_optional_tensor_ptr(topk_length), ku::get_optional_tensor_ptr(extra_topk_length), nullptr, (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(), num_splits->data_ptr(), impl_meta.num_sm_parts, at::cuda::getCurrentCUDAStream().stream() }; smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); } // Stick the metadata pointers to `params` KU_CHECK_DEVICE(tile_scheduler_metadata); KU_CHECK_DEVICE(num_splits); KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); KU_CHECK_DTYPE(num_splits, torch::kInt32); KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); KU_CHECK_CONTIGUOUS(num_splits); KU_CHECK_SHAPE(tile_scheduler_metadata, impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int)); KU_CHECK_SHAPE(num_splits, b+1); params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(); params.num_splits_ptr = num_splits->data_ptr(); params.num_sm_parts = impl_meta.num_sm_parts; // Allocate intermediate buffers for split-KV const int total_num_splits = b + impl_meta.num_sm_parts; lse_accum = 
torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat)); o_accum = torch::empty({total_num_splits, s_q, h_q, d_v}, opts.dtype(at::kFloat)); KU_CHECK_CONTIGUOUS(lse_accum); KU_CHECK_CONTIGUOUS(o_accum); params.lse_accum = lse_accum.data_ptr(); params.o_accum = o_accum.data_ptr(); params.stride_lse_accum_split = int64_stride_to_int(lse_accum.stride(0)); params.stride_lse_accum_s_q = int64_stride_to_int(lse_accum.stride(1)); params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0)); params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1)); params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2)); impl->run(params, features); CombineParams combine_params = { b, s_q, h_q, d_v, params.lse, params.out, params.stride_lse_b, params.stride_lse_s_q, params.stride_o_b, params.stride_o_s_q, params.stride_o_h_q, params.lse_accum, params.o_accum, params.stride_lse_accum_split, params.stride_lse_accum_s_q, params.stride_o_accum_split, params.stride_o_accum_s_q, params.stride_o_accum_h_q, params.tile_scheduler_metadata_ptr, params.num_splits_ptr, params.num_sm_parts, ku::get_optional_tensor_ptr(attn_sink), at::cuda::getCurrentCUDAStream().stream() }; smxx::decode::run_flash_mla_combine_kernel(combine_params); delete impl; return {out, lse.transpose(1, 2), tile_scheduler_metadata, num_splits}; } ================================================ FILE: csrc/api/sparse_fwd.h ================================================ #pragma once #include "common.h" #include "params.h" #include "sm90/prefill/sparse/phase1.h" #include "sm100/prefill/sparse/fwd/head128/phase1.h" #include "sm100/prefill/sparse/fwd/head64/phase1.h" #include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h" /* Feature flags a sparse prefill request may require; each *_Impl class below lists the subset it supports via DECLARE_SUPPORTED_FEATURES. */ enum class FwdFeatures : int { HEAD_64, HEAD_128, HEAD_DIM_576, HEAD_DIM_512, ATTN_SINK, SINK_LSE, TOPK_LENGTH }; class FwdImplBase : public ImplBase< SparseAttnFwdParams, FwdFeatures > {}; class Fwd_Sm90_Impl : public FwdImplBase { DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_64, FwdFeatures::HEAD_128, FwdFeatures::HEAD_DIM_512, FwdFeatures::HEAD_DIM_576, FwdFeatures::ATTN_SINK, FwdFeatures::SINK_LSE, FwdFeatures::TOPK_LENGTH ) protected: void run_(const SparseAttnFwdParams ¶ms, const std::vector &required_features) override { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { sm90::fwd::run_fwd_phase1_kernel(params); }); }); } }; class Fwd_Sm100_Head64_Impl : public FwdImplBase { DECLARE_SUPPORTED_FEATURES( FwdFeatures::HEAD_64, FwdFeatures::HEAD_DIM_512, FwdFeatures::HEAD_DIM_576, FwdFeatures::ATTN_SINK, FwdFeatures::SINK_LSE, FwdFeatures::TOPK_LENGTH ) protected: void run_(const SparseAttnFwdParams ¶ms, const std::vector &required_features) override { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { sm100::fwd::head64::run_fwd_phase1_kernel(params); }); } }; class Fwd_Sm100_Head128_Impl : public FwdImplBase { DECLARE_SUPPORTED_FEATURES( FwdFeatures::HEAD_128, FwdFeatures::HEAD_DIM_512, FwdFeatures::HEAD_DIM_576, FwdFeatures::ATTN_SINK, FwdFeatures::SINK_LSE, FwdFeatures::TOPK_LENGTH ) protected: void run_(const SparseAttnFwdParams ¶ms, const std::vector &required_features) override { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { sm100::fwd::head128::run_fwd_phase1_kernel(params); }); } }; class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase { DECLARE_SUPPORTED_FEATURES( FwdFeatures::HEAD_128, FwdFeatures::HEAD_DIM_512, FwdFeatures::ATTN_SINK, FwdFeatures::SINK_LSE, FwdFeatures::TOPK_LENGTH ) protected: void run_(const SparseAttnFwdParams ¶ms, const std::vector &required_features) override { sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel(params); } }; /* Sparse-attention prefill entry point. Checks q/kv/indices (and the optional attn_sink/topk_length) for rank, device, dtype, shape and last-dim contiguity; allocates out [s_q, h_q, d_v] plus float32 max_logits and lse [s_q, h_q]; collects the required FwdFeatures; then runs the SM90 kernel, or on SM100 the head-64 or head-128 kernel. Returns {out, max_logits, lse}. */ static std::vector sparse_attn_prefill_interface( const at::Tensor &q, const at::Tensor &kv, const at::Tensor &indices, float sm_scale, int d_v, const std::optional &attn_sink, const std::optional &topk_length ) { using bf16 = cutlass::bfloat16_t; 
Arch arch = Arch(); bool is_sm90a = arch.is_sm90a(); bool is_sm100f = arch.is_sm100f(); TORCH_CHECK(is_sm90a || is_sm100f, "Sparse Attention Forward Kernel is only supported on SM90a and SM100f architectures."); KU_CHECK_NDIM(q, 3); KU_CHECK_NDIM(kv, 3); KU_CHECK_NDIM(indices, 3); KU_CHECK_NDIM(attn_sink, 1); KU_CHECK_NDIM(topk_length, 1); int s_q = q.size(0); int s_kv = kv.size(0); int h_q = q.size(1); int h_kv = kv.size(1); int d_qk = q.size(2); int topk = indices.size(2); bool have_topk_length = topk_length.has_value(); TORCH_CHECK(d_qk == 576 || d_qk == 512, "Invalid d_qk: ", d_qk); /* Separator added so the offending value is not glued onto the message text. */ TORCH_CHECK(d_v == 512, "Invalid d_v: ", d_v); KU_CHECK_DEVICE(q); KU_CHECK_DEVICE(kv); KU_CHECK_DEVICE(indices); KU_CHECK_DEVICE(attn_sink); KU_CHECK_DEVICE(topk_length); KU_CHECK_DTYPE(q, torch::kBFloat16); KU_CHECK_DTYPE(kv, torch::kBFloat16); KU_CHECK_DTYPE(indices, torch::kInt32); KU_CHECK_DTYPE(attn_sink, torch::kFloat32); KU_CHECK_DTYPE(topk_length, torch::kInt32); KU_CHECK_SHAPE(q, s_q, h_q, d_qk); KU_CHECK_SHAPE(kv, s_kv, h_kv, d_qk); KU_CHECK_SHAPE(indices, s_q, h_kv, topk); KU_CHECK_SHAPE(attn_sink, h_q); KU_CHECK_SHAPE(topk_length, s_q); KU_CHECK_LAST_DIM_CONTIGUOUS(q); KU_CHECK_LAST_DIM_CONTIGUOUS(kv); KU_CHECK_LAST_DIM_CONTIGUOUS(indices); KU_CHECK_LAST_DIM_CONTIGUOUS(attn_sink); KU_CHECK_LAST_DIM_CONTIGUOUS(topk_length); // Allocate results and buffers at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); at::Tensor out = torch::empty({s_q, h_q, d_v}, opts); at::Tensor lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); at::Tensor max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); KU_CHECK_CONTIGUOUS(out); KU_CHECK_CONTIGUOUS(lse); KU_CHECK_CONTIGUOUS(max_logits); SparseAttnFwdParams params = { s_q, s_kv, h_q, h_kv, d_qk, d_v, topk, sm_scale, sm_scale * LOG_2_E, (bf16*)q.data_ptr(), (bf16*)kv.data_ptr(), (int*)indices.data_ptr(), ku::get_optional_tensor_ptr(attn_sink), ku::get_optional_tensor_ptr(topk_length), 
int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)), int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)), (bf16*)out.data_ptr(), (float*)max_logits.data_ptr(), (float*)lse.data_ptr(), arch.num_sms, at::cuda::getCurrentCUDAStream().stream() }; std::vector required_features; if (h_q == 64) { required_features.push_back(FwdFeatures::HEAD_64); } else if (h_q == 128) { required_features.push_back(FwdFeatures::HEAD_128); } else { TORCH_CHECK(false, "Unsupported h_q: ", h_q); } if (d_qk == 576) { required_features.push_back(FwdFeatures::HEAD_DIM_576); } else if (d_qk == 512) { required_features.push_back(FwdFeatures::HEAD_DIM_512); } else { TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); } if (attn_sink.has_value()) { required_features.push_back(FwdFeatures::ATTN_SINK); } if (have_topk_length) { required_features.push_back(FwdFeatures::TOPK_LENGTH); } if (is_sm90a) { Fwd_Sm90_Impl fwd_impl; fwd_impl.run(params, required_features); } else if (is_sm100f) { if (h_q == 64) { Fwd_Sm100_Head64_Impl fwd_impl; fwd_impl.run(params, required_features); } else if (h_q == 128) { /* Use the small-topk kernel when topk <= 1280 and it supports every required feature, or whenever the regular head-128 kernel cannot serve the request (e.g. an unsupported feature set). */ Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl; Fwd_Sm100_Head128_Impl regular_impl; bool use_small_topk_impl = false; if ( (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) || !regular_impl.check_if_all_features_are_supported(required_features) ) { use_small_topk_impl = true; } if (use_small_topk_impl) { small_topk_impl.run(params, required_features); } else { regular_impl.run(params, required_features); } } else { TORCH_CHECK(false, "Unsupported h_q: ", h_q); } } else { TORCH_CHECK(false, "Unsupported architecture"); } return {out, max_logits, lse}; } ================================================ FILE: csrc/defines.h ================================================ #pragma once #include #include using bf16 = cutlass::bfloat16_t; using fp8 = 
cutlass::float_e4m3_t; using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; using cutlass::arch::fence_view_async_shared; using cutlass::arch::fence_barrier_init; using cutlass::arch::NamedBarrier; /* Eight-element aggregates: 8 ints, 4 float2 pairs, and 4 __nv_bfloat162 pairs respectively. */ struct int32x8_t { int a0, a1, a2, a3, a4, a5, a6, a7; }; struct float8 { float2 a01, a23, a45, a67; }; struct bf16x8 { __nv_bfloat162 a01; __nv_bfloat162 a23; __nv_bfloat162 a45; __nv_bfloat162 a67; }; ================================================ FILE: csrc/kerutils/include/kerutils/common/common.h ================================================ #pragma once namespace kerutils {} /* KU_PRINTLN(fmt, ...): cute::print the formatted arguments, then a newline. Both calls are namespace-qualified so the macro also expands correctly outside `using namespace cute`. */ #define KU_PRINTLN(fmt, ...) { cute::print(fmt, ##__VA_ARGS__); cute::print("\n"); } namespace ku = kerutils; ================================================ FILE: csrc/kerutils/include/kerutils/device/common.h ================================================ /* Common data types and macros that are used across the kerutils library. */ #pragma once #include #include #include #include #include // For CUTE_DEVICE namespace kerutils { // Cache hints enum class CacheHint { EVICT_FIRST, EVICT_NORMAL, EVICT_LAST, EVICT_UNCHANGED, NO_ALLOCATE }; // Prefetch size enum class PrefetchSize { B64, B128, B256 }; using nvbf16 = __nv_bfloat16; using nvbf16x2 = __nv_bfloat162; using nve4m3 = __nv_fp8_e4m3; using nve4m3x2 = __nv_fp8x2_e4m3; using nve4m3x4 = __nv_fp8x4_e4m3; using bf16 = cutlass::bfloat16_t; using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) #define KERUTILS_ENABLE_SM80 #elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) static_assert(false, "kerutils doesn't support SM architectures below SM80"); #endif #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #define KERUTILS_ENABLE_SM90 #endif #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000)) #define KERUTILS_ENABLE_SM90A #endif #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) #define KERUTILS_ENABLE_SM100 #endif #if 
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) #define KERUTILS_ENABLE_SM100A #endif /* IDE parsers (CLion / VSCode) get every KERUTILS_ENABLE_* macro so that all SM-specific code paths are visible to them. */ #if (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) #define KERUTILS_ENABLE_SM80 #define KERUTILS_ENABLE_SM90 #define KERUTILS_ENABLE_SM90A #define KERUTILS_ENABLE_SM100 #define KERUTILS_ENABLE_SM100A #endif ================================================ FILE: csrc/kerutils/include/kerutils/device/device.cuh ================================================ #pragma once #include "kerutils/common/common.h" #include "common.h" #include "sm80/intrinsics.cuh" #include "sm80/helpers.cuh" #include "sm90/intrinsics.cuh" #include "sm90/helpers.cuh" #include "sm100/intrinsics.cuh" #include "sm100/helpers.cuh" #include "sm100/gemm.cuh" #include "sm100/tma_cta_group2_nosplit.cuh" ================================================ FILE: csrc/kerutils/include/kerutils/device/sm100/gemm.cuh ================================================ #pragma once #include #include namespace cute { // Extensions to CuTe // CuTe doesn't support UTCMMA with .ws, so we add it here // Besides, CuTe's UTCMMA has an `elect_one_sync()` inside which is really disgusting, so we have our own variant without `elect_one_sync()` here /* Weight-stationary (tcgen05.mma.ws) single-CTA UMMA: A comes from tensor memory (32-bit TMEM address), B from a shared-memory descriptor, issued without elect_one_sync(). M must be 32/64/128 and N 64/128/256. */ template struct SM100_MMA_F16BF16_WS_TS_NOELECT { static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); static_assert(N == 64 || N == 128 || N == 256, "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256"); using DRegisters = void; using ARegisters = uint32_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[1]; CUTE_HOST_DEVICE static void fma(uint32_t const& tmem_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) { asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t" "}\n" : : "r"(tmem_c), "r"(tmem_a), 
"l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); } }; /* MMA_Traits for the WS TS atom: A in 1-SM tensor memory, B as a shared-memory descriptor, C in weight-stationary 1-SM tensor memory. */ template struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types"); using FrgTypeA = UMMA::tmem_frg_1sm; // Actually this should be "duplicated", however, our great CuTe doesn't allow us to set it to "duplicated", so we just set it to NonInterleaved for a correct address calculation using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_ws_1sm; // Logical shape-K is always 256 bits; transform to units of elements static constexpr int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; using ThrID = Layout<_1>; using ALayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using BLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using CLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; // Accumulate or overwrite C.
1: read C, 0: ignore C [clear accumulators] UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); template CUTE_HOST_DEVICE constexpr friend void mma_unpack(MMA_Traits const& traits, Tensor & D, Tensor const& A, Tensor const& B, Tensor const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); uint32_t tmem_a = raw_pointer_cast(A.data()); uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); SM100_MMA_F16BF16_WS_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; /* Weight-stationary (tcgen05.mma.ws) single-CTA UMMA with A and B both given as shared-memory descriptors, issued without elect_one_sync(). M must be 32/64/128 and N 64/128/256. */ template struct SM100_MMA_F16BF16_WS_SS_NOELECT { static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); static_assert(N == 64 || N == 128 || N == 256, "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256"); using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[1]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) { asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" "}\n" : : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); } }; template struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT 
supports 16bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_ws_1sm; // Logical shape-K is always 256bits, transform to units of elements static constexpr int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; using ThrID = Layout<_1>; using ALayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using BLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using CLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; template CUTE_HOST_DEVICE constexpr friend void mma_unpack(MMA_Traits const& traits, Tensor & D, Tensor const& A, Tensor const& B, Tensor const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); uint64_t desc_a = A[0]; uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; /* Two-CTA (cta_group::2) UMMA: A from tensor memory, B from a shared-memory descriptor, issued without elect_one_sync(). */ template struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT { static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 16 between 16 and 256."); static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; using BRegisters = uint64_t[1]; using CRegisters = 
uint32_t[1]; CUTE_HOST_DEVICE static void fma(uint32_t const& tmem_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" "}\n" : : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); #else CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); #endif } }; /* MMA_Traits for the 2x1SM TS atom: A in 2-SM tensor memory, B as a shared-memory descriptor, C in 2-SM tensor memory; ThrID has size 2 (one entry per CTA of the pair). */ template struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types"); using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_2sm; // Size of instructions' K extent is always 256 bits; convert to units of element constexpr static int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; using ThrID = Layout<_2>; using ALayout = Layout,Int>>, Stride,Stride< _1,Int>>>; using BLayout = Layout,Int>>, Stride,Stride< _1,Int>>>; using CLayout = Layout,Int>>, Stride,Stride< _1,Int>>>; // Accumulate or overwrite C.
1: read C, 0: ignore C [clear accumulators] UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); template CUTE_HOST_DEVICE constexpr friend void mma_unpack(MMA_Traits const& traits, Tensor & D, Tensor const& A, Tensor const& B, Tensor const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); /* TMEM addresses are 32-bit, matching fma()'s uint32_t tmem_a parameter and the other TS atoms. */ uint32_t tmem_a = raw_pointer_cast(A.data()); uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); SM100_MMA_F16BF16_2x1SM_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; // SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync() template struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT { static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 16 between 16 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[1]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" "}\n" : : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(mask[0]), "r"(mask[1]), 
"r"(mask[2]), "r"(mask[3]), "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); #else CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); #endif } }; /* MMA_Traits for the 2x1SM SS atom: A and B as shared-memory descriptors, C in 2-SM tensor memory. */ template struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_2sm; // Size of instructions's K extent is always 256bits, convert to units of element constexpr static int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; using ThrID = Layout<_2>; using ALayout = Layout,Int>>, Stride,Stride< _1,Int>>>; using BLayout = Layout,Int>>, Stride,Stride< _1,Int>>>; using CLayout = Layout,Int>>, Stride,Stride< _1,Int>>>; UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); // Accumulate or overwrite C.
1: read C, 0: ignore C [clear accumulators] UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; template CUTE_HOST_DEVICE constexpr friend void mma_unpack(MMA_Traits const& traits, Tensor & D, Tensor const& A, Tensor const& B, Tensor const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); uint64_t desc_a = A[0]; uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); SM100_MMA_F16BF16_2x1SM_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; /* Single-CTA (cta_group::1) UMMA with A from tensor memory and B from a shared-memory descriptor, issued without elect_one_sync(). The four mask words passed to the instruction are all zero (presumably the disable-output-lane mask; confirm against the PTX spec). */ template struct SM100_MMA_F16BF16_TS_NOELECT { static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_TS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA."); static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_F16BF16_TS_NOELECT N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ or a multiple of 16 between 16 and 256 for M=128."); static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_TS_NOELECT A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[1]; CUTE_HOST_DEVICE static void fma(uint32_t const& tmem_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" "}\n" : : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); } }; template struct MMA_Traits> { using 
ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS_NOELECT supports 16bit types"); using FrgTypeA = UMMA::tmem_frg_1sm; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_1sm; // Logical shape-K is always 256 bits; transform to units of elements static constexpr int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; using ThrID = Layout<_1>; using ALayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using BLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using CLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); template CUTE_HOST_DEVICE constexpr friend void mma_unpack(MMA_Traits const& traits, Tensor & D, Tensor const& A, Tensor const& B, Tensor const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); uint32_t tmem_a = raw_pointer_cast(A.data()); uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); SM100_MMA_F16BF16_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; /* Single-CTA (cta_group::1) UMMA with A and B both from shared-memory descriptors, issued without elect_one_sync(). */ template struct SM100_MMA_F16BF16_SS_NOELECT { static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA."); static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_F16BF16_SS_NOELECT N-mode size 
should be a multiple of 8 between 8 and 256 for M=64,\ or a multiple of 16 between 16 and 256 for M=128."); using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[1]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" "}\n" : : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); } }; /* MMA_Traits for the single-CTA SS atom: A and B as shared-memory descriptors, C in 1-SM tensor memory. */ template struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_NOELECT supports 16bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_1sm; // Logical shape-K is always 256bits, transform to units of elements static constexpr int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; using ThrID = Layout<_1>; using ALayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using BLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; using CLayout = Layout,Int>>, Stride<_0,Stride< _1,Int>>>; UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); // Accumulate or overwrite C.
1: read C, 0: ignore C [clear accumulators] UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; template CUTE_HOST_DEVICE constexpr friend void mma_unpack(MMA_Traits const& traits, Tensor & D, Tensor const& A, Tensor const& B, Tensor const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); uint64_t desc_a = A[0]; uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); SM100_MMA_F16BF16_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; } ================================================ FILE: csrc/kerutils/include/kerutils/device/sm100/helpers.cuh ================================================ #pragma once #include #include "kerutils/device/common.h" namespace kerutils { // Perform SS UTCMMA // sA and sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tC_frag should be tmem fragment template< typename TiledMMA, typename TensorA, typename TensorB, typename TensorFragC > /* Issues one cute::gemm per K-block (mode 2) of the partitioned A/B fragments into tC_frag. With clear_accum == true only the first K-block overwrites C (ScaleOut::Zero) and the rest accumulate; otherwise every K-block accumulates. Leaves tiled_mma.accumulate_ set to ScaleOut::One on return when at least one K-block exists. */ CUTE_DEVICE void utcmma_ss( TiledMMA &tiled_mma, TensorA sA, TensorB sB, TensorFragC tC_frag, bool clear_accum ) { using namespace cute; tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter auto sA_frag = thr_mma.partition_fragment_A(sA); auto sB_frag = thr_mma.partition_fragment_B(sB); static_assert(size<2>(sA_frag) == size<2>(sB_frag)); static_assert(size<1>(sA_frag) == size<1>(tC_frag)); static_assert(size<1>(sB_frag) == size<2>(tC_frag)); CUTE_UNROLL for (int k = 0; k < size<2>(sA_frag); ++k) { cute::gemm( tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), tC_frag ); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } } // Perform TS UTCMMA // sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tA_frag and tC_frag should be tmem fragment template< typename TiledMMA, typename TensorA, typename TensorB, typename TensorFragC > /* Same K-block loop as utcmma_ss, but A is an already-partitioned tensor-memory fragment, so only B is partitioned here. */ CUTE_DEVICE void utcmma_ts( TiledMMA &tiled_mma, TensorA tA_frag, TensorB sB, TensorFragC tC_frag, bool clear_accum ) { using namespace cute; tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter auto sB_frag = thr_mma.partition_fragment_B(sB); static_assert(size<2>(tA_frag) == size<2>(sB_frag)); CUTE_UNROLL for (int k = 0; k < size<2>(tA_frag); ++k) { cute::gemm( tiled_mma, tA_frag(_, _, k), sB_frag(_, _, k), tC_frag ); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } } /* Builds a layout by tiling a swizzle atom over the requested shape with Step<_1, _2>, then coalescing. The atom is picked from SWIZZLE by a std::conditional_t chain that ends in void; the static_assert rejects that void case. */ template static constexpr auto make_umma_canonical_k_major_layout() { using namespace cute; using base_atom_type = \ std::conditional_t, std::conditional_t, std::conditional_t, std::conditional_t, void > > > >; static_assert(!std::is_same_v, "Invalid SWIZZLE value"); return coalesce(tile_to_shape( base_atom_type{}, Shape, Int>{}, Step<_1, _2>{} ), Shape<_1, _1>{}); } /* MN-major counterpart of the builder above: same atom selection, tiled with Step<_2, _1>. */ template static constexpr auto make_umma_canonical_mn_major_layout() { using namespace cute; using base_atom_type = \ std::conditional_t, std::conditional_t, std::conditional_t, std::conditional_t, void > > > >; static_assert(!std::is_same_v, "Invalid SWIZZLE value"); return coalesce(tile_to_shape( base_atom_type{}, Shape, Int>{}, Step<_2, _1>{} ), Shape<_1, _1>{}); } /* Dispatches to the K-major builder when MAJOR is UMMA::Major::K, otherwise to the MN-major builder. */ template auto make_umma_canonical_layout() { if constexpr (MAJOR == cute::UMMA::Major::K) { return make_umma_canonical_k_major_layout(); } else { return make_umma_canonical_mn_major_layout(); } } } ================================================ FILE: csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh ================================================ #pragma once #include "kerutils/device/common.h" namespace kerutils { // tma gather4 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor) // Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t 
cache_hint) { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr); asm volatile( "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" : : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), "r"(mbar_addr), "l"(cache_hint) : "memory" ); } // tma gather4 prefetch (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor) // Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios CUTE_DEVICE void tma_gather4_prefetch(const void* desc_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) { asm volatile( "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint [%0, {%1, %2, %3, %4, %5}], %6;\n" : : "l"(desc_ptr), "r"(col_idx), "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), "l"(cache_hint) ); } // tma gather4 with cta_group::2, allowing for synchronization across CTAs within a pair of CTAs (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor) template CUTE_DEVICE void tma_gather4_cta_group_2(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) { uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr); if constexpr (USE_CTA0_MBAR) { mbar_addr &= cute::Sm100MmaPeerBitMask; } asm volatile( "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" : : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), "r"(mbar_addr), 
"l"(cache_hint) : "memory" ); } // Vectorized addition for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add) CUTE_DEVICE float2 float2_add(const float2 &a, const float2 &b) { float2 c; asm volatile( "add.f32x2 %0, %1, %2;\n" : "=l"(reinterpret_cast(c)) : "l"(reinterpret_cast(a)), "l"(reinterpret_cast(b)) ); return c; } // Vectorized multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-mul) CUTE_DEVICE float2 float2_mul(const float2 &a, const float2 &b) { float2 c; asm volatile( "mul.f32x2 %0, %1, %2;\n" : "=l"(reinterpret_cast(c)) : "l"(reinterpret_cast(a)), "l"(reinterpret_cast(b))); return c; } // Vectorized fused multiply-add for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma) CUTE_DEVICE float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { // return a*b+c float2 d; asm volatile( "fma.rn.f32x2 %0, %1, %2, %3;\n" : "=l"(reinterpret_cast(d)) : "l"(reinterpret_cast(a)), "l"(reinterpret_cast(b)), "l"(reinterpret_cast(c))); return d; } // Vectorized negation for float32 (implemented as multiplication by {-1, -1}) CUTE_DEVICE float2 float2_neg(const float2 &a) { float2 t = {-1.0f, -1.0f}; return float2_mul(a, t); } // st.bulk (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk) CUTE_DEVICE void st_bulk(void* dst_ptr, int64_t size) { uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); asm volatile ( "st.bulk.weak.shared::cta [%0], %1, 0;\n" : : "r"(dst_addr), "l"(size) : "memory" ); } struct CUTE_ALIGNAS(16) CLCResponseObj { // An opaque 16B value char opaque[16]; }; struct CLCResult { int is_valid; int x, y, z; }; // Issue a CLC try_cancel query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel) CUTE_DEVICE void issue_clc_query(transac_bar_t &bar, CLCResponseObj &response_obj) 
{ uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque); uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar); asm volatile( "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [%0], [%1];\n" : : "r"(response_addr), "r"(mbarrier_addr) ); } // Issue a CLC try_cancel query with .multicast::cluster::all (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel) CUTE_DEVICE void issue_clc_query_multicast_cluster_all(transac_bar_t &bar, CLCResponseObj &response_obj) { uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque); uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar); asm volatile( "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n" : : "r"(response_addr), "r"(mbarrier_addr) ); } // Get the result of a CLC query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel) // In this function, we separate get_first_ctaid::x/y/z and hope PTXAS's dead code elimination can remove unnecessary instructions /* result.is_valid is set to 1 only when the response reports a canceled CTA; x/y/z are written only in that case, so they must not be read when is_valid == 0. */ template CUTE_DEVICE CLCResult get_clc_query_response(CLCResponseObj &response_obj) { uint32_t response_addr = cute::cast_smem_ptr_to_uint(&response_obj); CLCResult result; #define EMIT_ASM(LD_MODIFIER) \ asm volatile( \ "{\n" \ ".reg .pred p1;\n\t" \ ".reg .b128 clc_result;\n\t" \ "ld" LD_MODIFIER ".shared.b128 clc_result, [%4];\n\t" \ "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" \ "selp.u32 %3, 1, 0, p1;\n\t" \ "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 %0, clc_result;\n\t" \ "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::y.b32.b128 %1, clc_result;\n\t" \ "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::z.b32.b128 %2, clc_result;\n\t" \ "}\n" \ : "=r"(result.x), 
"=r"(result.y), "=r"(result.z), "=r"(result.is_valid) \ : "r"(response_addr) \ : "memory" \ ); if constexpr (USE_LD_ACQUIRE) { EMIT_ASM(".acquire.cta"); } else { EMIT_ASM(""); } return result; } /* EMIT_ASM is a helper local to get_clc_query_response; undefine it so it does not leak into (or collide with) code in every file that includes this header. */ #undef EMIT_ASM // LDG.256 or LDG.256 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld) // We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function // NC_STR should be either "" or ".nc" // L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate" // L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last" // L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B" #define KU_LDG_256(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \ { \ static_assert(std::is_pointer_v || std::is_array_v, "`global_addr` must be a pointer"); \ static_assert(std::is_pointer_v || std::is_array_v, "`result` must be a pointer"); \ uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \ asm volatile( \ "ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v4.u64 {%0, %1, %2, %3}, [%4];\n" \ : "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]), \ "=l"(result_as_uint64_ptr[2]), "=l"(result_as_uint64_ptr[3]) \ : "l"(global_addr) \ ); \ } // STG.256 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st) // L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate" // L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last" #define KU_STG_256(global_addr, src, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR) \ { \ static_assert(std::is_pointer_v || std::is_array_v, "`global_addr` must be a pointer"); \ static_assert(std::is_pointer_v || std::is_array_v, "`src` must 
be a pointer"); \ uint64_t const* src_as_uint64_ptr = (uint64_t const*)(src); \ asm volatile( \ "st.global.L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".v4.u64 [%0], {%1, %2, %3, %4};\n" \ : \ : "l"(global_addr), "l"(src_as_uint64_ptr[0]), "l"(src_as_uint64_ptr[1]), \ "l"(src_as_uint64_ptr[2]), "l"(src_as_uint64_ptr[3]) \ ); \ } } namespace kerutils { // tcgen05.commit.cta_group::1 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) CUTE_DEVICE void umma_arrive_noelect(transac_bar_t &bar) { uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar); asm volatile( "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\n" : :"r"(bar_intptr) ); } // tcgen05.commit.cta_group::1, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) CUTE_DEVICE void umma_arrive_multicast_noelect(transac_bar_t &bar, uint16_t cta_mask) { uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar); asm volatile( "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n" : :"r"(bar_intptr), "h"(cta_mask) ); } // tcgen05.commit.cta_group::2 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) CUTE_DEVICE void umma_arrive_2x1SM_noelect(transac_bar_t &bar) { uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar); asm volatile( "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];\n" : :"r"(bar_intptr) ); } // tcgen05.commit.cta_group::2, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) CUTE_DEVICE void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &bar, uint16_t cta_mask) { uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar); asm volatile( "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n" : :"r"(bar_intptr), "h"(cta_mask) ); } // 
tcgen05.fence::before_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence) __device__ __forceinline__ void tcgen05_before_thread_sync() { asm volatile("tcgen05.fence::before_thread_sync;"); } // tcgen05.fence::after_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence) __device__ __forceinline__ void tcgen05_after_thread_sync() { asm volatile("tcgen05.fence::after_thread_sync;"); } // Load from tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld) template __device__ __forceinline__ void tmem_ld_32dp32bNx(uint32_t tmem_start, void* data_) { uint32_t* data = (uint32_t*)data_; static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements"); // NOTE The following code crashes VSCode intellisense engine, so we disable it #ifndef __VSCODE_IDE__ [&](cute::index_sequence) { if constexpr (kNumElements == 1) { cute::SM100_TMEM_LOAD_32dp32b1x::copy(tmem_start, data[Is]...); } else if constexpr (kNumElements == 2) { cute::SM100_TMEM_LOAD_32dp32b2x::copy(tmem_start, data[Is]...); } else if constexpr (kNumElements == 4) { cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_start, data[Is]...); } else if constexpr (kNumElements == 8) { cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_start, data[Is]...); } else if constexpr (kNumElements == 16) { cute::SM100_TMEM_LOAD_32dp32b16x::copy(tmem_start, data[Is]...); } else if constexpr (kNumElements == 32) { cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, data[Is]...); } else if constexpr (kNumElements == 64) { cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, data[Is]...); } else if constexpr (kNumElements == 128) { cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, data[Is]...); } 
}(cute::make_index_sequence{}); #endif } // Load from tensor memory, 16 data path lanes, 128-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld) template __device__ __forceinline__ void tmem_ld_16dp128bNx(uint32_t tmem_start, void* data_) { uint32_t* data = (uint32_t*)data_; static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32 || kNumReplications == 64, "Invalid kNumReplications"); #ifndef __VSCODE_IDE__ [&](cute::index_sequence) { if constexpr (kNumReplications == 1) { cute::SM100_TMEM_LOAD_16dp128b1x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 2) { cute::SM100_TMEM_LOAD_16dp128b2x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 4) { cute::SM100_TMEM_LOAD_16dp128b4x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 8) { cute::SM100_TMEM_LOAD_16dp128b8x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 16) { cute::SM100_TMEM_LOAD_16dp128b16x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 32) { cute::SM100_TMEM_LOAD_16dp128b32x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 64) { cute::SM100_TMEM_LOAD_16dp128b64x::copy(tmem_start, data[Is]...); } }(cute::make_index_sequence{}); #endif } // Load from tensor memory, 16 data path lanes, 256-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld) template __device__ __forceinline__ void tmem_ld_16dp256bNx(uint32_t tmem_start, void* data_) { uint32_t* data = (uint32_t*)data_; static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32, "Invalid kNumReplications"); #ifndef __VSCODE_IDE__ [&](cute::index_sequence) { if constexpr (kNumReplications == 1) { cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 2) { cute::SM100_TMEM_LOAD_16dp256b2x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 4) { cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 8) { cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 16) { cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start, data[Is]...); } else if constexpr (kNumReplications == 32) { cute::SM100_TMEM_LOAD_16dp256b32x::copy(tmem_start, data[Is]...); } }(cute::make_index_sequence{}); #endif } // Store into tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st) template __device__ __forceinline__ void tmem_st_32dp32bNx(uint32_t tmem_start, void const* data_) { uint32_t const* data = (uint32_t const*)data_; static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements"); #ifndef __VSCODE_IDE__ [&](cute::index_sequence) { if constexpr (kNumElements == 1) { cute::SM100_TMEM_STORE_32dp32b1x::copy(data[Is]..., tmem_start); } else if constexpr (kNumElements == 2) { cute::SM100_TMEM_STORE_32dp32b2x::copy(data[Is]..., tmem_start); } else if constexpr (kNumElements == 4) { cute::SM100_TMEM_STORE_32dp32b4x::copy(data[Is]..., tmem_start); } else if constexpr (kNumElements == 8) { cute::SM100_TMEM_STORE_32dp32b8x::copy(data[Is]..., tmem_start); } else if constexpr (kNumElements == 16) { cute::SM100_TMEM_STORE_32dp32b16x::copy(data[Is]..., tmem_start); } else if constexpr (kNumElements == 32) { cute::SM100_TMEM_STORE_32dp32b32x::copy(data[Is]..., tmem_start); } else if constexpr (kNumElements == 64) { cute::SM100_TMEM_STORE_32dp32b64x::copy(data[Is]..., tmem_start); } else if constexpr (kNumElements == 128) { cute::SM100_TMEM_STORE_32dp32b128x::copy(data[Is]..., tmem_start); } }(cute::make_index_sequence{}); #endif } } ================================================ FILE: csrc/kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh ================================================ #pragma once #include #include namespace cute { // Extensions to CuTe // CuTe's built-in SM100_TMA_2SM_LOAD_1D series requires the number of participating threads to be 2 (using ThrID = Layout<_2>;) and also splits the data, which is really annoying to use, so we modified our own version. 
Additionally, to keep it consistent with other parts that use SM90 TMA, we made it accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100. //////////////////////////////////////////////////////////////////////////////////////////////////// /// TMA_LOAD : Initiates a TMA copy from global memory to shared memory //////////////////////////////////////////////////////////////////////////////////////////////////// struct SM100_TMA_2SM_LOAD_1D_NOSPLIT { CUTE_HOST_DEVICE static void copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, [[maybe_unused]] void * smem_ptr, [[maybe_unused]] int32_t const& crd0) { #if defined(CUTE_ARCH_TMA_SM100_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); // Executed by both CTAs. Set peer bit to 0 so that the // transaction bytes will update CTA0's barrier. uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3}], [%2], %4;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); #endif } }; struct SM100_TMA_2SM_LOAD_2D_NOSPLIT { CUTE_HOST_DEVICE static void copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, [[maybe_unused]] void * smem_ptr, [[maybe_unused]] int32_t const& crd0, int32_t const& crd1) { #if defined(CUTE_ARCH_TMA_SM100_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); // Executed by both CTAs. Set peer bit to 0 so that the // transaction bytes will update CTA0's barrier. 
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4}], [%2], %5;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); #endif } }; struct SM100_TMA_2SM_LOAD_3D_NOSPLIT { CUTE_HOST_DEVICE static void copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, [[maybe_unused]] void * smem_ptr, [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { #if defined(CUTE_ARCH_TMA_SM100_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); // Executed by both CTAs. Set peer bit to 0 so that the // transaction bytes will update CTA0's barrier. uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5}], [%2], %6;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); #endif } }; struct SM100_TMA_2SM_LOAD_4D_NOSPLIT { CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { #if defined(CUTE_ARCH_TMA_SM100_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); // Executed by both CTAs. Set peer bit to 0 so that the // transaction bytes will update CTA0's barrier. 
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); #endif } }; struct SM100_TMA_2SM_LOAD_5D_NOSPLIT { CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { #if defined(CUTE_ARCH_TMA_SM100_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); // Executed by both CTAs. Set peer bit to 0 so that the // transaction bytes will update CTA0's barrier. uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); #endif } }; struct SM100_TMA_2SM_LOAD_NOSPLIT { CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0) { return SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); } CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1) { return 
SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); } CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { return SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); } CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { return SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); } CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { return SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); } using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; }; struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {}; // The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar // Use .with(tma_mbar) to construct an executable version template struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; // SM100_TMA_2SM_LOAD_NOSPLIT arguments TmaDescriptor tma_desc_; using AuxParams = AuxParams_; AuxParams aux_params_; // Return TmaDescriptor/TensorMap CUTE_HOST_DEVICE constexpr TmaDescriptor const* get_tma_descriptor() const { return &tma_desc_; } // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar CUTE_HOST_DEVICE constexpr Copy_Traits with( uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 
const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; } // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) CUTE_HOST_DEVICE constexpr Copy_Traits with( TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; } // Generate the TMA coord tensor template CUTE_HOST_DEVICE constexpr auto get_tma_tensor(GShape const& g_shape) const { static_assert(is_congruent::value); return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); } // Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with() template CUTE_HOST_DEVICE friend constexpr void copy_unpack(Copy_Traits const& traits, Tensor const& src, Tensor & dst) = delete; }; // The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar template struct Copy_Traits : TMA_LOAD_Unpack { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; // SM100_TMA_2SM_LOAD_NOSPLIT arguments tuple< TmaDescriptor const*, uint64_t*, // smem mbarrier uint64_t // cache hint > const opargs_; CUTE_HOST_DEVICE Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) : opargs_(desc, mbar, cache) {} }; } ================================================ FILE: csrc/kerutils/include/kerutils/device/sm80/helpers.cuh ================================================ #pragma once #include "kerutils/device/common.h" #include "kerutils/device/sm80/intrinsics.cuh" 
namespace kerutils {

// Retrieve the value of `%smid` and check its range
// Executes `trap` when the id is not strictly below `num_physical_sms`; otherwise returns it.
CUTE_DEVICE uint32_t get_sm_id_with_range_check(uint32_t num_physical_sms) {
    uint32_t sm_id = get_sm_id();
    if (!(sm_id < num_physical_sms)) {
        trap();
    }
    return sm_id;
}

// Device-side assert that only emits `trap` (no message) when `cond` is false.
#ifndef KU_TRAP_ONLY_DEVICE_ASSERT
#define KU_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) \
        asm("trap;"); \
} while (0)
#endif

// Construct a `float2` from a single `float` by duplicating the value
CUTE_DEVICE float2 float2float2(const float &x) {
    return float2 {x, x};
}

// 128-bit store to shared memory; `ptr` is a generic pointer converted with __cvta_generic_to_shared.
CUTE_DEVICE void st_shared(void* ptr, __int128_t val) {
    asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}

// float4 overload: reinterprets the 16 bytes of `val` as __int128_t and stores them.
CUTE_DEVICE void st_shared(void* ptr, float4 val) {
    st_shared(ptr, *(__int128_t*)&val);
}

// 128-bit load from shared memory; `ptr` is a generic pointer converted with __cvta_generic_to_shared.
CUTE_DEVICE __int128_t ld_shared(void* ptr) {
    __int128_t val;
    asm volatile("ld.shared.b128 %0, [%1];" : "=q"(val) : "l"(__cvta_generic_to_shared(ptr)));
    return val;
}

// Same as ld_shared, with the 16 loaded bytes reinterpreted as a float4.
CUTE_DEVICE float4 ld_shared_float4(void* ptr) {
    __int128_t temp = ld_shared(ptr);
    return *(float4*)&temp;
}

}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh
================================================
#pragma once

#include "kerutils/device/common.h"

namespace kerutils {

// cp.async.cg (cache global) with prefetch and predicate (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async)
// Copies 16 bytes from `src` (global) to `dst` (shared). When `pred` is false the src-size operand is 0,
// so no bytes are read from `src`.
template CUTE_DEVICE void cp_async_cacheglobal(const void* src, void* dst, bool pred=true) {
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
    if constexpr (PREFETCH_SIZE == PrefetchSize::B64) {
        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16, %2;\n" :: "r"(dst_addr), "l"(src), "r"(pred?16:0));
    } else if constexpr (PREFETCH_SIZE == PrefetchSize::B128) {
        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16, %2;\n" :: "r"(dst_addr), "l"(src), "r"(pred?16:0));
    } else if constexpr (PREFETCH_SIZE == PrefetchSize::B256) {
        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16, %2;\n" :: "r"(dst_addr), "l"(src), "r"(pred?16:0));
    } else {
        static_assert(PREFETCH_SIZE == PrefetchSize::B64 || PREFETCH_SIZE == PrefetchSize::B128 || PREFETCH_SIZE == PrefetchSize::B256, "Unsupported prefetch size for cp_async_cacheglobal.");
    }
}

// Create fraction-based cache policy (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-createpolicy)
// Emits `createpolicy.fractional.L2::<primary>.L2::<secondary>.b64` with the given `fraction`.
// Primary may be EVICT_FIRST/NORMAL/LAST/UNCHANGED; secondary may be EVICT_FIRST/UNCHANGED. Other combinations fail a static_assert.
template CUTE_DEVICE int64_t create_fraction_based_cache_policy(float fraction = 1.0f) {
    int64_t result;
    #define EMIT(PRIMARY_PRIORITY_STR, SECONDARY_PRIORITY_STR) \
        asm volatile( \
            "createpolicy.fractional.L2::" PRIMARY_PRIORITY_STR ".L2::" SECONDARY_PRIORITY_STR ".b64 %0, %1;\n" \
            : "=l"(result) \
            : "f"(fraction) \
        );
    #define EMIT2(PRIMARY_PRIORITY_STR) \
    { \
        if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_FIRST) { \
            EMIT(PRIMARY_PRIORITY_STR, "evict_first") \
        } else if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED) { \
            EMIT(PRIMARY_PRIORITY_STR, "evict_unchanged") \
        } else { \
            static_assert(SECONDARY_PRIORITY == CacheHint::EVICT_FIRST || \
                          SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED, \
                          "Unsupported secondary cache hint for create_fraction_based_cache_policy."); \
        } \
    }
    if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_FIRST) {
        EMIT2("evict_first");
    } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL) {
        EMIT2("evict_normal");
    } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_LAST) {
        EMIT2("evict_last");
    } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED) {
        EMIT2("evict_unchanged");
    } else {
        static_assert(PRIMARY_PRIORITY == CacheHint::EVICT_FIRST || PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL || PRIMARY_PRIORITY == CacheHint::EVICT_LAST || PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED, "Unsupported primary cache hint for create_fraction_based_cache_policy.");
    }
    #undef EMIT
    #undef EMIT2
    return result;
}

// Create a simple cache policy (equivalent to create_fraction_based_cache_policy(1.0f))
// The same as cute::TMA::CacheHintSmXX
// Returns a compile-time policy constant instead of executing `createpolicy`.
template CUTE_DEVICE constexpr int64_t create_simple_cache_policy() {
    if constexpr (CACHE_HINT == CacheHint::EVICT_FIRST) {
        return 0x12F0000000000000;  // Result of createpolicy.fractional.L2::evict_first.b64
    } else if constexpr (CACHE_HINT == CacheHint::EVICT_NORMAL) {
        return 0x1000000000000000;  // Copied from CuTe. Unsure about the exact meaning. (TODO Change to 0x16F0000000000000?)
    } else if constexpr (CACHE_HINT == CacheHint::EVICT_LAST) {
        return 0x14F0000000000000;  // Result of createpolicy.fractional.L2::evict_last.b64
    } else {
        static_assert(CACHE_HINT == CacheHint::EVICT_FIRST || CACHE_HINT == CacheHint::EVICT_NORMAL || CACHE_HINT == CacheHint::EVICT_LAST, "Unsupported cache hint for create_simple_cache_policy.");
    }
}

// AtomicAdd (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
// Performs `red.relaxed.gpu.global.add.f32` of `data` into `global_addr` with L2 policy `cache_policy`, only when `pred` is nonzero.
CUTE_DEVICE void atomicadd_f32_with_policy_and_pred(void* global_addr, const float &data, int64_t cache_policy, uint32_t pred = true) {
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        // Any nonzero `pred` enables the reduction (C++ truthiness). Comparing against exactly 1
        // silently skipped the add for other nonzero values.
        "setp.ne.u32 p, %3, 0;\n\t"
        "@p red.relaxed.gpu.global.add.L2::cache_hint.f32 [%1], %0, %2; \n\t"
        "}"
        :
        : "f"(data), "l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
    );
}

// Get the id of the current SM
// About %smid (https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid): PTX document says that %smid ranges from 0 to %nsmid-1, while "The SM identifier numbering is not guaranteed to be contiguous, so %nsmid may be larger than the physical number of SMs in the device.". However, result shows that, at least for sm90 and sm100f, %nsmid is the number of physical SMs - 1 (NOTE(review): this probably means the largest observed %smid is the number of physical SMs - 1; confirm). For the sake of safety, I recommend you to check the return of get_sm_id manually or call `get_sm_id_with_range_check()` defined in `device/sm80/helpers.cuh`.
// Besides, PTX document also says that this number may change due to preemption, but currently this never happens according to [DATEN GELÖSCHT] CUTE_DEVICE uint32_t get_sm_id() { uint32_t ret; asm volatile("mov.u32 %0, %%smid;\n" : "=r"(ret)); return ret; } // trap (https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-trap) CUTE_DEVICE void trap() { asm volatile("trap;\n"); } // LDG.128 or LDG.128 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld) // We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function // NC_STR should be either "" or ".nc" // L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate" // L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B" // L2 cache hint is not supported since it's only supported for LDG.256 #define KU_LDG_128(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \ { \ static_assert(std::is_pointer_v || std::is_array_v, "`global_addr` must be a pointer"); \ static_assert(std::is_pointer_v || std::is_array_v, "`result` must be a pointer"); \ uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \ asm volatile( \ "ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v2.u64 {%0, %1}, [%2];\n" \ : "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]) \ : "l"(global_addr) \ ); \ } } ================================================ FILE: csrc/kerutils/include/kerutils/device/sm90/helpers.cuh ================================================ #pragma once #include #include "kerutils/device/common.h" namespace kerutils { template< typename TMA, typename Tensor0, typename Tensor1 > CUTE_DEVICE void launch_tma_copy( const TMA &tma_copy, Tensor0 src, Tensor1 dst, transac_bar_t &bar, const cute::TMA::CacheHintSm90 &cache_hint = 
cute::TMA::CacheHintSm90::EVICT_NORMAL ) { auto thr_tma = tma_copy.get_slice(cute::_0{}); cute::copy( tma_copy.with(reinterpret_cast(bar), 0, cache_hint), thr_tma.partition_S(src), thr_tma.partition_D(dst) ); } // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a CUTE_DEVICE int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); return row_idx; } // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in some rows. This function converts the local_elem_idx to the actual col_idx // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a CUTE_DEVICE int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) { int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1); return col_idx; } template CUTE_DEVICE void wgmma(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC, bool zero_init) { using namespace cute; constexpr bool Is_RS = !cute::is_base_of::value; // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } warpgroup_fence_operand(tCrC); warpgroup_arrive(); tiled_mma.accumulate_ = zero_init ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } if constexpr (commit) { warpgroup_commit_batch(); } warpgroup_fence_operand(tCrC); if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } } template CUTE_DEVICE void wgmma_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { using namespace cute; ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); Tensor sA_frag = thr_mma.partition_fragment_A(sA); Tensor sB_frag = thr_mma.partition_fragment_B(sB); static_assert(size<2>(sA_frag) == size<2>(sB_frag)); warpgroup_fence_operand(rC_frag); warpgroup_arrive(); tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; CUTLASS_PRAGMA_UNROLL for (int k = 0; k < size<2>(sA_frag); ++k) { cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } warpgroup_fence_operand(rC_frag); } template CUTE_DEVICE void wgmma_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { using namespace cute; ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); Tensor sB_frag = thr_mma.partition_fragment_B(sB); static_assert(size<2>(rA_frag) == size<2>(sB_frag)); warpgroup_fence_operand(const_cast(rA_frag)); warpgroup_fence_operand(rC_frag); warpgroup_arrive(); tiled_mma.accumulate_ = clear_accum ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; CUTLASS_PRAGMA_UNROLL for (int k = 0; k < size<2>(rA_frag); ++k) { cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } warpgroup_fence_operand(rC_frag); warpgroup_fence_operand(const_cast(rA_frag)); } } ================================================ FILE: csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh ================================================ #pragma once #include "kerutils/device/common.h" namespace kerutils { // st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async) template CUTE_DEVICE static void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) { static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async."); long2 data_long2 = *reinterpret_cast(&data); uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar); asm volatile ( "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" : : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) ); } static constexpr int PEER_ADDR_MASK = 16777216; // Given an address in the current CTA, return the corresponding address in the peer CTA template CUTE_DEVICE T* get_peer_addr(const T* p) { return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); } // Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0) template CUTE_DEVICE T* get_cta0_addr(const T* p) { constexpr int CTA0_ADDR_MASK = 0xFEFFFFFF; return (T*)((int64_t)(p) & CTA0_ADDR_MASK); } // TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. 
// (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk)
// Issue an asynchronous bulk f32 add-reduction of `store_bytes` bytes from shared memory
// (`src_ptr`) into global memory (`dst_ptr`). Completion is tracked through the bulk
// async-group mechanism (.bulk_group); the caller presumably commits and waits on the group.
CUTE_DEVICE void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
    uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(src_ptr);
    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
        :
        : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes)
        : "memory");
}

// Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE void barrier_cluster_arrive_release() {
    asm volatile("barrier.cluster.arrive.release;" : : : "memory");
}

// Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
// No "memory" clobber: this arrive does not order the thread's memory operations.
CUTE_DEVICE void barrier_cluster_arrive_relaxed() {
    asm volatile("barrier.cluster.arrive.relaxed;" : : :);
}

// Cluster barrier wait with .acquire modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE void barrier_cluster_wait_acquire() {
    asm volatile("barrier.cluster.wait.acquire;" : : : "memory");
}

// mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
// Arrive once on the CTA-local mbarrier `mbar`; the returned state is discarded (`_`).
CUTE_DEVICE void mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) {
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&mbar);
    asm volatile(
        "{\n\t"
        "mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\n\t"
        "}"
        :
        : "r"(smem_addr));
}

// AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
// Relaxed, GPU-scope vector atomic add of the four floats in `data` into `global_addr`,
// with the L2 cache-policy hint `cache_policy`. The reduction is issued only when `pred`
// is non-zero. The .v4 access presumably requires `global_addr` to be 16-byte aligned.
//
// The predicate used to be `setp.eq.u32 p, pred, 1`, which silently skipped the add for
// any non-zero `pred` other than 1 (e.g. a bit-mask test result). `pred != 0` gives the
// usual truthiness implied by the `= true` default and is identical for the values 0 and 1.
CUTE_DEVICE void atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) {
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.u32 p, %6, 0;\n\t"
        "@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t"
        "}"
        :
        : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w), "l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
    );
}

// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)
// Copy `load_bytes` bytes from this CTA's shared memory (`src_ptr`) to the shared::cluster
// address `dst_ptr` (possibly a peer CTA's buffer, e.g. obtained via get_peer_addr). The
// byte count is signalled on `mbar` via complete_tx.
CUTE_DEVICE void cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) {
    uint32_t dst_smem_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
    uint32_t src_smem_addr = cute::cast_smem_ptr_to_uint(src_ptr);
    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);
    asm volatile(
        "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
        :
        : "r"(dst_smem_addr), "r"(src_smem_addr), "r"(load_bytes), "r"(mbar_addr)
    );
}

}
================================================
FILE:
csrc/kerutils/include/kerutils/host/host.h ================================================ #pragma once #include #include #include #include #include #include #include #include "kerutils/common/common.h" namespace kerutils { class KUException final : public std::exception { std::string message = {}; public: template explicit KUException(const char *name, const char* file, const int line, Args&&... args) { std::ostringstream oss; oss << name << " error (" << file << ":" << line << "): "; (oss << ... << args); message = oss.str(); } const char *what() const noexcept override { return message.c_str(); } }; #define THROW_KU_EXCEPTION(name, ...) \ throw kerutils::KUException(name, __FILE__, __LINE__, __VA_ARGS__) #define KU_CUDA_CHECK(call) \ do { \ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ THROW_KU_EXCEPTION("CUDA", "CUDA error: ", cudaGetErrorString(status_)); \ } \ } while(0) #define KU_CUTLASS_CHECK(call) \ do { \ cutlass::Status status_ = call; \ if (status_ != cutlass::Status::kSuccess) { \ fprintf(stderr, "CUTLASS error (%s:%d): %d\n", __FILE__, __LINE__, static_cast(status_)); \ THROW_KU_EXCEPTION("CUTLASS", "CUTLASS error: ", static_cast(status_)); \ } \ } while(0) // This `KU_ASSERT` is triggered no matter if the code is compiled with `-DNDEBUG` or not. #define KU_ASSERT(cond, ...) 
\ do { \ if (not (cond)) { \ fprintf(stderr, "Assertion `%s` failed (%s:%d): ", #cond, __FILE__, __LINE__); \ if constexpr (sizeof(#__VA_ARGS__) > 1) { \ fprintf(stderr, ", " __VA_ARGS__); \ } \ fprintf(stderr, "\n"); \ THROW_KU_EXCEPTION("Assertion", "Assertion `", #cond, "` failed."); \ } \ } while(0) #define KU_CHECK_KERNEL_LAUNCH() KU_CUDA_CHECK(cudaGetLastError()) template inline __host__ __device__ constexpr T ceil_div(const T &a, const T &b) { return (a + b - 1) / b; } template inline __host__ __device__ constexpr T ceil(const T &a, const T &b) { return (a + b - 1) / b * b; } // A wrapper for make_tensor_map static inline CUtensorMap make_tensor_map( const std::vector &size, const std::vector &strides, // PAY ATTENTION: In BYTES const std::vector &box_size, void* global_ptr, CUtensorMapDataType data_type, CUtensorMapSwizzle swizzle_mode, CUtensorMapL2promotion l2_promotion, CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, const std::vector &element_strides_ = {} ) { int dim = size.size(); KU_ASSERT(dim >= 1); std::vector element_strides; if (element_strides_.empty()) { for (int i = 0; i < dim; ++i) element_strides.push_back(1); } else { element_strides = element_strides_; } KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim); CUtensorMap result; CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( &result, data_type, dim, global_ptr, size.data(), strides.data(), box_size.data(), element_strides.data(), interleave_mode, swizzle_mode, l2_promotion, oob_fill ); if (ret_code != CUresult::CUDA_SUCCESS) { auto print_vector = [&](auto t, const char* fmt, const char end='\n') { for (auto elem : t) { printf(fmt, elem); } printf("%c", end); }; fprintf(stderr, "Failed to create tensormap\n"); fprintf(stderr, "Dim: %d\n", dim); 
printf("size: "); print_vector(size, "%lu "); printf("strides: "); print_vector(strides, "%lu "); printf("box_size: "); print_vector(box_size, "%u "); printf("element_strides: "); print_vector(element_strides, "%u "); printf("global ptr: 0x%lx\n", (int64_t)global_ptr); printf("data_type: %d\n", (int)data_type); printf("swizzle_mode: %d\n", (int)swizzle_mode); printf("l2_promotion: %d\n", (int)l2_promotion); printf("interleave_mode: %d\n", (int)interleave_mode); printf("oob_fill: %d\n", (int)oob_fill); KU_ASSERT(false); } return result; } // Given strides (in number of elements), this function converts their datatype in uint64_t and then multiplies by elem_size template static inline std::vector make_stride_helper(const std::vector &strides_in_elems, size_t elem_size) { std::vector res; for (auto stride : strides_in_elems) { res.push_back(((uint64_t)stride) * elem_size); } return res; } } ================================================ FILE: csrc/kerutils/include/kerutils/kerutils.cuh ================================================ #pragma once #include "host/host.h" #include "device/device.cuh" ================================================ FILE: csrc/kerutils/include/kerutils/supplemental/torch_tensors.h ================================================ #pragma once #include #include #include "kerutils/common/common.h" namespace kerutils { // Check whether the given tensor or optional tensor satisfies the given condition // If tensor_or_opt is a tensor, check_fn is applied directly // If tensor_or_opt is an optional tensor, check_fn is applied only when the optional has value template static inline bool _check_optional_tensor(const T& tensor_or_opt, const std::function& check_fn) { if constexpr (std::is_same::value) { return check_fn(tensor_or_opt); } else { if (tensor_or_opt.has_value()) { return check_fn(tensor_or_opt.value()); } else { return true; } } } // Get the pointer of the given tensor // Return (PtrT*)tensor.data_ptr() if the tensor has a backend 
storage, nullptr otherwise template static inline PtrT* get_tensor_ptr(const at::Tensor& tensor) { if (tensor.has_storage()) { return (PtrT*)tensor.data_ptr(); } else { return nullptr; } } // Get the pointer of the given tensor or optional tensor // Return (PtrT*)tensor.data_ptr() if tensor_or_opt has value and points to a valid tensor, return nullptr otherwise template static inline PtrT* get_optional_tensor_ptr(const T& tensor_or_opt) { if constexpr (std::is_same::value) { return get_tensor_ptr(tensor_or_opt); } else { if (tensor_or_opt.has_value()) { return get_tensor_ptr(*tensor_or_opt); } else { return nullptr; } } } } // Check whether the given tensor (or optional) is on cuda #define KU_CHECK_DEVICE(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_cuda(); }), #tensor " must be on CUDA") // Check whether the given tensor (or optional) has the given number of dimensions #define KU_CHECK_NDIM(tensor, ndim) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dim() == (ndim); }), #tensor " must have " #ndim " dimensions") // Check whether the given tensor (or optional) has the given shape #define KU_CHECK_SHAPE(tensor, ...) 
TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.sizes() == torch::IntArrayRef({__VA_ARGS__}); }), #tensor " must have shape (" #__VA_ARGS__ ")") // Check whether the given tensor (or optional) is contiguous #define KU_CHECK_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_contiguous(); }), #tensor " must be contiguous") // Check whether the last dimention of the given tensor (or optional) #define KU_CHECK_LAST_DIM_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.size(-1) == 1 || t.stride(-1) == 1; }), #tensor " must have contiguous last dimension") // Check whether the given tensor (or optional) has the specified dtype #define KU_CHECK_DTYPE(tensor, target_dtype) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.dtype() == (target_dtype); }), #tensor " must have dtype " #target_dtype) ================================================ FILE: csrc/params.h ================================================ #pragma once #include "cutlass/bfloat16.h" enum class ModelType { V32, MODEL1 }; struct __align__(4*8) DecodingSchedMeta { int begin_req_idx, end_req_idx; // Both inclusive int begin_block_idx, end_block_idx; // Inclusive, exclusive int begin_split_idx; int is_first_req_splitted, is_last_req_splitted; int _pad[1]; }; static constexpr int DecodingSchedMetaSize = sizeof(DecodingSchedMeta); struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams using index_t = int64_t; int b; // batch size int s_q; int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q int d, d_v; // K/V dimension int h_q, h_k; // The number of Q/K heads int num_blocks; // Number of blocks in total int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k bool is_causal; float scale_softmax, scale_softmax_log2; void *__restrict__ q_ptr; void *__restrict__ k_ptr; void *__restrict__ 
o_ptr; float *__restrict__ softmax_lse_ptr; index_t q_batch_stride; index_t k_batch_stride; index_t o_batch_stride; index_t q_row_stride; index_t k_row_stride; index_t o_row_stride; index_t q_head_stride; index_t k_head_stride; index_t o_head_stride; int *__restrict__ block_table; index_t block_table_batch_stride; int page_block_size; int *__restrict__ seqlens_k_ptr; DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr; int num_sm_parts; int *__restrict__ num_splits_ptr; int total_num_splits; float *__restrict__ softmax_lseaccum_ptr; float *__restrict__ oaccum_ptr; cudaStream_t stream; }; struct SparseAttnDecodeParams { int b, s_q; int h_q, h_kv; int d_qk, d_v; float sm_scale, sm_scale_div_log2; int num_blocks, page_block_size, topk; ModelType model_type; cutlass::bfloat16_t* __restrict__ q; // [b, s_q, h_q, d_qk] cutlass::bfloat16_t* __restrict__ kv; // [num_blocks, page_block_size, d_qk] int* __restrict__ indices; // [b, s_q, topk] int* __restrict__ topk_length; // [b], may be nullptr float* __restrict__ attn_sink; // [h_q], may be nullptr float* __restrict__ lse; // [b, s_q, h_q] cutlass::bfloat16_t* __restrict__ out; // [b, s_q, h_q, d_v] int extra_num_blocks, extra_page_block_size, extra_topk; cutlass::bfloat16_t* __restrict__ extra_kv; // [extra_num_blocks, extra_page_block_size, d_qk] int* __restrict__ extra_indices; // [b, s_q, extra_topk] int* __restrict__ extra_topk_length; // [b], may be nullptr int stride_q_b, stride_q_s_q, stride_q_h_q; int stride_kv_block, stride_kv_row; int stride_indices_b, stride_indices_s_q; int stride_lse_b, stride_lse_s_q; int stride_o_b, stride_o_s_q, stride_o_h_q; int stride_extra_kv_block, stride_extra_kv_row; int stride_extra_indices_b, stride_extra_indices_s_q; cudaStream_t stream; // SplitKV-related parameters float* __restrict__ lse_accum; // [num_splits, s_q, h_q] float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v] int stride_lse_accum_split, stride_lse_accum_s_q; int stride_o_accum_split, 
stride_o_accum_s_q, stride_o_accum_h_q; DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous int num_sm_parts; }; struct CombineParams { int b, s_q, h_q, d_v; float* __restrict__ lse; // [b, s_q, h_q] void* __restrict__ out; // [b, s_q, h_q, d_v] int stride_lse_b, stride_lse_s_q; int stride_o_b, stride_o_s_q, stride_o_h_q; float* __restrict__ lse_accum; // [num_splits, s_q, h_q] float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v] int stride_lse_accum_split, stride_lse_accum_s_q; int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q; DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous int num_sm_parts; float* attn_sink; // [h_q], may be nullptr cudaStream_t stream; }; struct GetDecodeSchedMetaParams { int b; // batch size int s_q; int block_size_n; int fixed_overhead_num_blocks; int topk, extra_topk; // -1 if sparse attention (or extra topk) is disabled int *__restrict__ topk_length, *__restrict__ extra_topk_length; int *__restrict__ seqlens_k_ptr; // Only necessary for dense attention DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr; int *__restrict__ num_splits_ptr; int num_sm_parts; cudaStream_t stream; }; struct SparseAttnFwdParams { int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk; float sm_scale, sm_scale_div_log2; // Input tensors cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk] cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk] int* __restrict__ indices; // [s_q, h_kv, topk] float* __restrict__ attn_sink; // [h_q], may be nullptr int* __restrict__ topk_length; // [s_q], may be nullptr // Strides int stride_q_s_q; int stride_q_h_q; int stride_kv_s_kv; int stride_kv_h_kv; int stride_indices_s_q; int stride_indices_h_kv; // Output tensors cutlass::bfloat16_t* __restrict__ out; // [s_q, h_q, d_v] float* 
__restrict__ max_logits;  // [s_q, h_q]
    float* __restrict__ lse;  // [s_q, h_q]

    int num_sm;
    cudaStream_t stream;
};

// We have some kernels that implement both prefill and decode modes in a single kernel (with different template instantiations). The following enum helps to distinguish the modes.
enum class SparseAttnFwdMode {
    Prefill,  // Normal prefill mode
    DecodeWithSplitKV,  // To trigger decoding mode for kernels that support both prefill and decode
};

// Compile-time flag for decode-mode instantiations (presumably true for
// SparseAttnFwdMode::DecodeWithSplitKV -- confirm against the template argument)
template
inline constexpr bool is_decode_v = std::bool_constant::value;

// Parameter struct for a given mode: SparseAttnDecodeParams when the condition holds
// (decode), otherwise SparseAttnFwdParams (prefill)
template
using SparseFwdArgT = std::conditional_t, SparseAttnDecodeParams, SparseAttnFwdParams>;

================================================
FILE: csrc/sm100/decode/head128/README.md
================================================
Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512) or simulated using 2x head64 kernel
================================================
FILE: csrc/sm100/decode/head64/config.h
================================================
#pragma once

#include "kernel.h"

#include
#include
#include
#include

#include "defines.h"
#include "params.h"

namespace sm100::decode::head64 {

using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using e8m0 = __nv_fp8_e8m0;
using e4m3 = cutlass::float_e4m3_t;
using namespace cute;

// Named-barrier IDs used by the head64 decode kernel
enum NamedBarriers : uint32_t {
    main_loop_sync = 0,
    wg0_sync = 1,
    wg0_warp02_sync = 2,
    wg0_warp13_sync = 3,
    everyone_sync = 4
};

// Compile-time configuration (dimensions, shared/tensor memory layouts, MMA types) of the
// sm100 head64 sparse FP8 decoding kernel, specialized on the model variant.
template
struct KernelTemplate {
    static constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512;
    static constexpr int D_K = D_Q;
    static constexpr int D_V = 512;
    static constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 512 : 448;
    static constexpr int D_ROPE = 64;
    static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
    static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true;
    // D_NOPE / QUANT_TILE_SIZE scales per token: 512/128 = 4 for V32, 448/64 = 7 padded to 8 for MODEL1
    static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8;  // Padding is included
    // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens and 2) be large enough to cover the entire KV cache. Since TMA copy's coordinates can only be 32-bit signed integers, this number must be >= 128, preferably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks.
    static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE;
    static_assert(D_NOPE + D_ROPE == D_Q);
    static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V));

    static constexpr int B_H = 64;
    static constexpr int B_TOPK = 64;
    static constexpr int NUM_BUFS = 2;
    static constexpr int NUM_INDEX_BUFS = 4;  // Number of buffers for indices (tma_coords) & is_token_valid & scales
    static constexpr int NUM_THREADS = 128*3;  // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant
    static constexpr float MAX_INIT_VAL = -1e30f;  // To avoid (-inf) - (-inf) = NaN

    static constexpr int D_Q_SW128 = 512;
    static constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0;
    static_assert(D_Q_SW128 + D_Q_SW64 == D_Q);
    static constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 64 : 128;  // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes

    template<
        typename Shape_Q_SW128, typename TMA_Q_SW128,
        typename Shape_O, typename TMA_O
    >
    struct TmaParams {
        Shape_Q_SW128 shape_Q_SW128;
        TMA_Q_SW128 tma_Q_SW128;
        Shape_O shape_O;
        TMA_O tma_O;
        CUtensorMap tensor_map_q_sw64;  // Invalid if D_Q_SW64 == 0
        CUtensorMap tensor_map_kv_nope;
        CUtensorMap tensor_map_kv_rope;
        CUtensorMap tensor_map_extra_kv_nope;
        CUtensorMap tensor_map_extra_kv_rope;
    };

    // Tensor memory columns
    struct tmem_cols {
        // 0 ~ 256: output
        // 256 ~ 256 + 64*D_Q/256: Q
        // 400 ~ 464: P
        static constexpr int O = 0;
        static constexpr int Q = 256;
        static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128;
        static constexpr int P = 400;
    };

    template
    using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
        UMMA::Layout_K_SW128_Atom{},
        Shape, Int>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    using SmemLayoutQ_SW128 = SmemLayoutQTiles;

    using SmemLayoutOBuf = decltype(tile_to_shape(
        UMMA::Layout_K_SW128_Atom{},
        Shape, Int>{}
    ));

    using SmemLayoutOBuf_TMA = decltype(tile_to_shape(
        UMMA::Layout_K_SW128_Atom{},
        Shape, Int<64>>{}
    ));  // A TMA tile

    static_assert(D_V == 512);
    using SmemLayoutOAccumBuf = Layout<
        Shape, Int>,
        Stride, _1>  // We use stride = 520 here to avoid bank conflict
    >;

    using SmemLayoutS = decltype(tile_to_shape(
        UMMA::Layout_K_INTER_Atom{},
        Shape, Int>{},
        Step<_1, _2>{}
    ));

    template
    using SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape(
        UMMA::Layout_K_SW128_Atom{},
        Shape, Int<64*NUM_TILES>>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    template
    using SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape(
        UMMA::Layout_K_SW128_Atom{},
        Shape, Int<64*NUM_TILES>>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    template
    using SmemLayoutKTilesTransposed_SW128 = decltype(composition(
        SmemLayoutKTiles_SW128{},
        Layout<
            Shape, Int>,
            Stride, _1>
        >{}
    ));

    template
    using SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape(
        UMMA::Layout_K_SW64_Atom{},
        Shape, Int<32*NUM_TILES>>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    template
    using SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape(
        UMMA::Layout_K_SW64_Atom{},
        Shape, Int<32*NUM_TILES>>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    template
    using SmemLayoutKTilesTransposed_SW64 = decltype(composition(
        SmemLayoutKTiles_SW64{},
        Layout<
            Shape, Int>,
            Stride, _1>
        >{}
    ));

    // Shared-memory plan. Q/O buffers alias the KV dequant/raw buffers through union `u`,
    // and the P exchange buffer aliases S through union `s_p`.
    struct SharedMemoryPlan {
        union {
            struct {
                array_aligned> q;
                bf16 q_sw64[B_H*D_Q_SW64];  // NOTE D_Q_SW64 may be 0 but array_aligned will have a size of 16, so we use a plain array here. The former tensor (`q`) promises its alignment.
                union {
                    array_aligned> o_buf;
                    array_aligned> o_accum_buf;
                } o;
            } qo;
            struct {
                struct {
                    array_aligned nope;  // NoPE part, dequantized
                    array_aligned rope;  // RoPE part, dequantized. SW64 in v32 mode, SW128 in MODEL1 mode
                } dequant[NUM_BUFS];
                static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q));  // So that Q does not cover raw_nope
                array_aligned raw_nope[NUM_BUFS];  // Raw (quantized) NoPE part
            } kv;
        } u;
        union {
            float4 p_exchange_buf[4][16 * B_TOPK / 4];
            array_aligned> s;
        } s_p;
        CUTE_ALIGNAS(16) float rowwise_max_buf[128];
        char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8];
        int tma_coord[NUM_INDEX_BUFS][B_TOPK];
        e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN];
        array_aligned tmem_start_addr;

        transac_bar_t bar_last_store_done;
        transac_bar_t bar_q_tma, bar_q_utccp;
        transac_bar_t bar_rope_ready[NUM_BUFS];
        transac_bar_t bar_nope_ready[NUM_BUFS];
        transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS];
        transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS];
        transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS];
    };

    using TiledMMA_P = decltype(make_tiled_mma(
        SM100_MMA_F16BF16_WS_TS_NOELECT{}
    ));  // *2 for dual gemm

    using TiledMMA_O = decltype(make_tiled_mma(
        SM100_MMA_F16BF16_WS_SS_NOELECT{}
    ));

    template
    static __device__ void flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params);

    static void run(const SparseAttnDecodeParams &params);
};

}
================================================
FILE: csrc/sm100/decode/head64/instantiations/model1.cu
================================================
#include "../kernel.cuh"

namespace sm100::decode::head64 {

template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);

}
================================================
FILE: csrc/sm100/decode/head64/instantiations/v32.cu
================================================
#include "../kernel.cuh"

namespace sm100::decode::head64 {

template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);

}
================================================
FILE: csrc/sm100/decode/head64/kernel.cuh
================================================
#include "kernel.h"

#include
#include
#include
#include
#include
#include

#include "kerutils/kerutils.cuh"

#include "utils.h"
#include "sm100/helpers.h"
#include "config.h"

namespace sm100::decode::head64 {

template
template
__device__ void
KernelTemplate ::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params) {
#if defined(KERUTILS_ENABLE_SM100A)
    // blockIdx.x selects the query position, blockIdx.y the scheduler partition
    const int s_q_idx = blockIdx.x;
    const int partition_idx = blockIdx.y;
    const int warpgroup_idx = cutlass::canonical_warp_group_idx();
    const int idx_in_warpgroup = threadIdx.x % 128;
    const int warp_idx = cutlass::canonical_warp_idx_sync();
    const int lane_idx = threadIdx.x % 32;

    extern __shared__ char wksp_buf[];
    SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf);

    // One elected thread of warp 0 prefetches the TMA descriptors
    if (warp_idx == 0 && elect_one_sync()) {
        cute::prefetch_tma_descriptor(tma_params.tma_Q_SW128.get_tma_descriptor());
        cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
        cute::prefetch_tma_descriptor(&tma_params.tensor_map_q_sw64);
        cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);
        cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);
    }

    if (warp_idx == 0) {
        if
(elect_one_sync()) { plan.bar_last_store_done.init(128); plan.bar_q_tma.init(1); plan.bar_q_utccp.init(1); for (int i = 0; i < NUM_BUFS; ++i) { plan.bar_rope_ready[i].init(1); plan.bar_nope_ready[i].init(128); plan.bar_raw_ready[i].init(1); plan.bar_raw_free[i].init(128); plan.bar_qk_done[i].init(1); plan.bar_so_ready[i].init(128); plan.bar_sv_done[i].init(1); } for (int i = 0; i < NUM_INDEX_BUFS; ++i) { plan.bar_valid_coord_scale_ready[i].init(32); plan.bar_valid_coord_scale_free[i].init(128+128+1+1); } cutlass::arch::fence_barrier_init(); } cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data()); KU_TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); cute::TMEM::Allocator1Sm().release_allocation_lock(); } __syncthreads(); struct MainLoopArgs { int batch_idx, start_block_idx, end_block_idx; bool is_no_split; int n_split_idx; bool bar_phase_batch_rel; // Bar phase of barriers that are used once per batch int topk_length, extra_topk_length, num_orig_kv_blocks; bool is_last_batch; }; auto run_main_loop = [&](auto f) { // NOTE Putting the following code outside the warpgroup specialization switch results in register spilling. // [[maybe_unused]] int begin_req_idx, end_req_idx, sched_begin_block_idx, sched_end_block_idx, begin_n_split_idx, is_first_req_splitted, is_last_req_splitted; DecodingSchedMeta sched_meta; KU_LDG_256( params.tile_scheduler_metadata_ptr + partition_idx, &sched_meta, ".nc", "no_allocate", "evict_normal", "256B" ); if (sched_meta.begin_req_idx >= params.b) { return; } bool bar_phase_batch_rel = 0; #pragma unroll 1 for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx, bar_phase_batch_rel ^= 1) { int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk; int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK); int extra_topk_length = params.extra_topk_length ? 
__ldg(params.extra_topk_length + batch_idx) : params.extra_topk; int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0 int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK; bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false); int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx); MainLoopArgs args = { batch_idx, start_block_idx, end_block_idx, !is_split, n_split_idx, bar_phase_batch_rel, topk_length, extra_topk_length, orig_topk_padded / B_TOPK, batch_idx == sched_meta.end_req_idx }; f(args); NamedBarrier(NUM_THREADS, NamedBarriers::everyone_sync).arrive_and_wait_unaligned(); } }; struct RingState { int buf_idx = 0; bool bar_phase = 0; int index_buf_idx = 0; bool index_bar_phase = 0; CUTE_DEVICE void update() { bar_phase ^= (buf_idx == NUM_BUFS-1); buf_idx = (buf_idx+1) % NUM_BUFS; index_bar_phase ^= (index_buf_idx == NUM_INDEX_BUFS-1); index_buf_idx = (index_buf_idx+1) % NUM_INDEX_BUFS; } }; RingState rs; if (warpgroup_idx == 0) { // Scale & Exp warpgroup // The same technique (and highly similar code) as the sm100 sparse prefill head64 kernel cutlass::arch::warpgroup_reg_alloc<224>(); constexpr int B_EPI = 64; // Must be equal to the size of the swizzle atom Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_buf.data()), SmemLayoutOBuf{}); bf16* sO_bases[B_EPI/8]; // 64 is the size of the swizzle atom (in number of elements) while 8 is the width of each write CUTE_UNROLL for (int i = 0; i < B_EPI/8; ++i) sO_bases[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*128 + i*8); const float2 scale = float2 {params.sm_scale_div_log2, 
params.sm_scale_div_log2}; bf16* sS_base = plan.s_p.s.data() + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2); float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg((float*)params.attn_sink + (idx_in_warpgroup%64)) * CUDART_L2E_F; run_main_loop([&](const MainLoopArgs &args) { cute::tma_store_wait<0>(); plan.bar_last_store_done.arrive(); float mi = MAX_INIT_VAL; float li = 0.0f; float real_mi = -CUDART_INF_F; CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // Make sure all intermediate buffers (including p_exchange_buf, rowwise max_buf) are free plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); // Put the barrier wait here for more code reordering space plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase); ku::tcgen05_after_thread_sync(); // Load P float p[B_TOPK/2], p_peer[B_TOPK/2]; if (warp_idx < 2) { ku::tmem_ld_32dp32bNx(tmem_cols::P, p); ku::tmem_ld_32dp32bNx(tmem_cols::P+32, p_peer); } else { ku::tmem_ld_32dp32bNx(tmem_cols::P, p_peer); ku::tmem_ld_32dp32bNx(tmem_cols::P+32, p); } cutlass::arch::fence_view_async_tmem_load(); ku::tcgen05_before_thread_sync(); // Reduce within shared mem { // Store // Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (row 0 ~ 31) part CUTE_UNROLL for (int i = 0; i < (B_TOPK/2)/4; ++i) plan.s_p.p_exchange_buf[warp_idx^2][i*32 + lane_idx] = *(float4*)(p_peer + i*4); NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1)); // Synchronize between warp 0 and warp 2, as well as warp 1 - warp 3 // Load CUTE_UNROLL for (int i = 0; i < (B_TOPK/2)/4; ++i) { float2 t[2]; *(float4*)t = plan.s_p.p_exchange_buf[warp_idx][i*32 + lane_idx]; float2* cur_p = (float2*)(p + i*4); cur_p[0] = ku::float2_add(cur_p[0], t[0]); cur_p[1] = ku::float2_add(cur_p[1], t[1]); } } // Since dual gemm is utilized, the layout of P in 
register now look like: // // 32 32 // +-------+-------+ // | | | // 32 | Warp0 | Warp2 | // | | | // +-------+-------+ // | | | // 32 | Warp1 | Warp3 | // | | | // +-------+-------+ // Mask uint32_t valid_mask = *((uint32_t*)plan.is_token_valid[rs.index_buf_idx] + (idx_in_warpgroup>=64?1:0)); CUTE_UNROLL for (int i = 0; i < B_TOPK/2; i += 1) { if (!(valid_mask>>i&1)) p[i] = -CUDART_INF_F; } // Get rowwise max of Pi float cur_pi_max = -CUDART_INF_F; CUTE_UNROLL for (int i = 0; i < (B_TOPK/2); i += 1) { cur_pi_max = max(cur_pi_max, p[i]); } cur_pi_max *= params.sm_scale_div_log2; plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // This also separates "reading p_exchange_buf" and "writing S" plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive(); cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]); real_mi = max(real_mi, cur_pi_max); bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); // By this point: // - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...) // - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. 
among threads 0~31+64~95; and is identical among threads 32~63+96~127) // Calc scale factor, and scale li float new_max, scale_for_old; if (!should_scale_o) { // Don't scale O scale_for_old = 1.0f; new_max = mi; } else { new_max = max(cur_pi_max, mi); scale_for_old = exp2f(mi - new_max); } mi = new_max; // mi is still identical within each row // Calculate S __nv_bfloat162 s[(B_TOPK/2)/2]; float2 neg_new_max = float2 {-new_max, -new_max}; float2 cur_sum = float2 {0.0f, 0.0f}; CUTE_UNROLL for (int i = 0; i < (B_TOPK/2)/2; i += 1) { float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale, neg_new_max); d.x = exp2f(d.x); d.y = exp2f(d.y); cur_sum = ku::float2_add(cur_sum, d); s[i] = __float22bfloat162_rn(d); } li = fma(li, scale_for_old, (cur_sum.x + cur_sum.y)); // Write S CUTE_UNROLL for (int i = 0; i < B_TOPK/2/8; i += 1) { *(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4); } // Scale O if (block_idx != args.start_block_idx && should_scale_o) { float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old}; ku::tcgen05_after_thread_sync(); static constexpr int CHUNK_SIZE = 64; float2 o[CHUNK_SIZE/2]; CUTE_UNROLL for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) { // Load O ku::tmem_ld_32dp32bNx(tmem_cols::O + chunk_idx*CHUNK_SIZE, o); cutlass::arch::fence_view_async_tmem_load(); // Mult for (int i = 0; i < CHUNK_SIZE/2; ++i) { o[i] = ku::float2_mul(o[i], scale_for_old_float2); } // Store O ku::tmem_st_32dp32bNx(tmem_cols::O + chunk_idx*CHUNK_SIZE, o); cutlass::arch::fence_view_async_tmem_store(); } ku::tcgen05_before_thread_sync(); } fence_view_async_shared(); plan.bar_so_ready[rs.buf_idx].arrive(); if (block_idx != args.end_block_idx-1) { rs.update(); // Don't update rs for the last round since we want to wait for the last SV gemm } } if (real_mi == -CUDART_INF_F) { // real_mi == -CUDART_INF_F <=> No valid TopK indices // We set li to 0 to fit the definition that li := exp(x[i] - mi) li = 0.0f; mi = -CUDART_INF_F; } // Exchange li 
plan.rowwise_max_buf[idx_in_warpgroup] = li; NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); li += plan.rowwise_max_buf[idx_in_warpgroup^64]; // Store li if (idx_in_warpgroup < B_H) { if (args.is_no_split) { float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li)); cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse; float* gSoftmaxLse = (float*)params.lse + args.batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + idx_in_warpgroup; *gSoftmaxLse = cur_lse; } else { float cur_lse = log2f(li) + mi; float* gSoftmaxLseAccum = (float*)params.lse_accum + args.n_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + idx_in_warpgroup; *gSoftmaxLseAccum = cur_lse; } } plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase); rs.update(); ku::tcgen05_after_thread_sync(); if (args.is_last_batch) { cudaTriggerProgrammaticLaunchCompletion(); } if (args.is_no_split) { Tensor tma_gO = flat_divide( tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, args.batch_idx), Shape, Int<64>>{} )(_, _, _0{}, _); auto thr_tma = tma_params.tma_O.get_slice(_0{}); Tensor tma_sO = flat_divide( sO, Shape, Int<64>>{} )(_, _, _0{}, _); float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li + exp2f(attn_sink - mi)); float2 o_scale_float2 = {o_scale, o_scale}; float2 o[B_EPI/2]; __nv_bfloat162 o_bf16[B_EPI/2]; CUTE_UNROLL for (int i = 0; i < (D_V/2) / B_EPI; ++i) { // Load ku::tmem_ld_32dp32bNx(tmem_cols::O + i*B_EPI, o); cutlass::arch::fence_view_async_tmem_load(); // Scale & Convert CUTE_UNROLL for (int j = 0; j < B_EPI/2; ++j) { o[j] = ku::float2_mul(o[j], o_scale_float2); o_bf16[j] = __float22bfloat162_rn(o[j]); } // Store int col_base = (i*B_EPI>=D_V/4 ? 
D_V/2 : 0) + (i*B_EPI%(D_V/4)); CUTE_UNROLL for (int j = 0; j < B_EPI / 8; ++j) *(__int128_t*)(sO_bases[j] + col_base*B_H) = *(__int128_t*)(&o_bf16[j*4]); // Sync fence_view_async_shared(); NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // S -> G if (warp_idx == 0 && elect_one_sync()) { cute::copy( tma_params.tma_O, thr_tma.partition_S(tma_sO(_, _, col_base/64)), thr_tma.partition_D(tma_gO(_, _, col_base/64)) ); } if (warp_idx == 1 && elect_one_sync()) { cute::copy( tma_params.tma_O, thr_tma.partition_S(tma_sO(_, _, col_base/64 + (D_V/4)/64)), thr_tma.partition_D(tma_gO(_, _, col_base/64 + (D_V/4)/64)) ); } } cute::tma_store_arrive(); } else { float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li); // Here we leave attn_sink to the combine kernel, otherwise attn_sink will take effect for multiple times float2 o_scale_float2 = {o_scale, o_scale}; constexpr int B_EPI = 64; float2 o[B_EPI/2]; Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_accum_buf.data()), SmemLayoutOAccumBuf{}); CUTE_UNROLL for (int i = 0; i < (D_V/2) / B_EPI; ++i) { // Load ku::tmem_ld_32dp32bNx(tmem_cols::O + i*B_EPI, o); cutlass::arch::fence_view_async_tmem_load(); // Scale & Convert CUTE_UNROLL for (int j = 0; j < B_EPI/2; ++j) o[j] = ku::float2_mul(o[j], o_scale_float2); // Store int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? 
D_V/2 : 0) + (i*B_EPI%(D_V/4)); CUTE_UNROLL for (int j = 0; j < B_EPI / 4; ++j) *(__int128_t*)&sO(idx_in_warpgroup%64, col_base + j*4) = *(__int128_t*)(&o[j*2]); } fence_view_async_shared(); NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); if (elect_one_sync()) { CUTE_UNROLL for (int local_row = 0; local_row < B_H/4; ++local_row) { int smem_row = local_row*4 + warp_idx; SM90_BULK_COPY_S2G::copy( &sO(smem_row, _0{}), (float*)params.o_accum + args.n_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + smem_row*params.stride_o_accum_h_q, D_V*sizeof(float) ); } cute::tma_store_arrive(); } } }); if (warp_idx == 0) { cute::TMEM::Allocator1Sm().free(0, 512); } } else if (warpgroup_idx == 1) { cutlass::arch::warpgroup_reg_dealloc<72>(); const int warp_idx = cutlass::canonical_warp_idx_sync(); // Missing this leads to reg spilling if (warp_idx == 4 && elect_one_sync()) { // MMA Warp run_main_loop([&](const MainLoopArgs &args) { if (args.start_block_idx >= args.end_block_idx) { ku::trap(); } // Issue Q (SW128) G->S { Tensor gQ = tma_params.tma_Q_SW128.get_tma_tensor(tma_params.shape_Q_SW128)(_, _, s_q_idx, args.batch_idx); Tensor sQ = make_tensor(make_smem_ptr(plan.u.qo.q.data()), SmemLayoutQ_SW128{}); ku::launch_tma_copy( tma_params.tma_Q_SW128, gQ, sQ, plan.bar_q_tma, TMA::CacheHintSm90::EVICT_FIRST ); } // Issue Q (SW64) G -> S if constexpr (D_Q_SW64 > 0) { cute::SM90_TMA_LOAD_5D::copy( &tma_params.tensor_map_q_sw64, (uint64_t*)&plan.bar_q_tma, (uint64_t)TMA::CacheHintSm90::EVICT_FIRST, plan.u.qo.q_sw64, 0, 0, 0, s_q_idx, args.batch_idx ); } plan.bar_q_tma.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16)); plan.bar_q_tma.wait(args.bar_phase_batch_rel); ku::tcgen05_after_thread_sync(); // Issue Q (SW128) UTCCP { UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc( make_tensor( make_smem_ptr(plan.u.qo.q.data()), tile_to_shape( UMMA::Layout_K_SW128_Atom{}, Shape, Int<64>>{} // *2 to leverage dual GEMM ) ) ); static_assert(D_Q_SW128%128 == 0); 
CUTE_UNROLL for (int tile_idx = 0; tile_idx < D_Q_SW128/128; ++tile_idx) { // Each tile: 64 x (64*2) logically, 128 x 64 bf16 on TMEM CUTE_UNROLL for (int subtile_idx = 0; subtile_idx < 64/16; ++subtile_idx) { // Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM SM100_UTCCP_128dp256bit_1cta::copy( sQ_desc + (tile_idx*(B_H*128) + subtile_idx*16) * 2 / 16, tmem_cols::Q + tile_idx*32 + subtile_idx*8 ); } } } // Issue Q (SW64) UTCCP if constexpr (D_Q_SW64 > 0) { UMMA::SmemDescriptor sQ_SW64_desc = UMMA::make_umma_desc( make_tensor( make_smem_ptr(plan.u.qo.q_sw64), tile_to_shape( UMMA::Layout_K_SW64_Atom{}, Shape, Int<32>>{} // *2 to leverage dual GEMM ) ) ); static_assert(D_Q_SW64%64 == 0); CUTE_UNROLL for (int tile_idx = 0; tile_idx < D_Q_SW64/64; ++tile_idx) { // Each tile: 64 x (32*2) logically, 128 x 32 bf16 on TMEM CUTE_UNROLL for (int subtile_idx = 0; subtile_idx < 32/16; ++subtile_idx) { // Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM SM100_UTCCP_128dp256bit_1cta::copy( sQ_SW64_desc + (tile_idx*(B_H*64) + subtile_idx*16) * 2 / 16, tmem_cols::Q + (B_H*D_Q_SW128/2/128) + tile_idx*16 + subtile_idx*8 ); } } } ku::umma_arrive_noelect(plan.bar_q_utccp); // Allocate tmem tensors TiledMMA tiled_mma_P = TiledMMA_P{}; TiledMMA tiled_mma_O = TiledMMA_O{}; // NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm) Tensor tP = partition_fragment_C(tiled_mma_P, Shape, _128>{}); Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); tP.data().get() = tmem_cols::P; tO.data().get() = tmem_cols::O; // Wait for UTCCP plan.bar_q_utccp.wait(args.bar_phase_batch_rel); ku::tcgen05_after_thread_sync(); // Mainloop CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { if constexpr (MODEL_TYPE == ModelType::V32) { // V3.2: RoPE behaves like an extra block with size 64, so we can do RoPE first // QK RoPE 
plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase); ku::tcgen05_after_thread_sync(); Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A( partition_shape_A(tiled_mma_P, Shape, Int>{}) ); tQ_rope.data().get() = tmem_cols::Q_Tail; Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].rope.data()), SmemLayoutKTiles_DualGemm_SW64<2/2>{}); ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true); // QK NoPE plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase); ku::tcgen05_after_thread_sync(); Tensor tQ_nope = tiled_mma_P.get_slice(_0{}).make_fragment_A( partition_shape_A(tiled_mma_P, Shape, Int>{}) ); tQ_nope.data().get() = tmem_cols::Q; Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{}); ku::utcmma_ts(tiled_mma_P, tQ_nope, sK_nope, tP, false); } else { // MODEL1: RoPE is the last 64 dims within the full 512 dim, which couples with the last 64 dim from the NoPE part when performing dual GEMM. i.e. 
// // logical view: |0|1|2|3|4|5|6|7| (where 7 is the RoPE part) // dual gemm's view: // |0|2|4|6| // |1|3|5|7| // // So we must wait for both the NoPE and the RoPE part, and then perform dual GEMM plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase); plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase); ku::tcgen05_after_thread_sync(); Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A( partition_shape_A(tiled_mma_P, Shape, Int>{}) ); tQ.data().get() = tmem_cols::Q; Tensor sK = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{}); ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true); } ku::umma_arrive_noelect(plan.bar_qk_done[rs.buf_idx]); // SV plan.bar_so_ready[rs.buf_idx].wait(rs.bar_phase); ku::tcgen05_after_thread_sync(); Tensor sS = make_tensor(make_smem_ptr(plan.s_p.s.data()), SmemLayoutS{}); Tensor sV = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTilesTransposed_SW128{}); // NOTE: For MODEL1, it "expands" to the RoPE part. ku::utcmma_ss(tiled_mma_O, sS, sV, tO, block_idx == args.start_block_idx); ku::umma_arrive_noelect(plan.bar_sv_done[rs.buf_idx]); rs.update(); } }); } else if (warp_idx == 5 && elect_one_sync()) { // Raw KV NoPE retrieval warp run_main_loop([&](const MainLoopArgs &args) { plan.bar_q_utccp.wait(args.bar_phase_batch_rel); plan.bar_last_store_done.wait(args.bar_phase_batch_rel); CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); plan.bar_raw_free[rs.buf_idx].wait(rs.bar_phase^1); int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0); int4 nxt_cur_indices; CUTE_UNROLL for (int row = 0; row < B_TOPK; row += 4) { if (row+4 < B_TOPK) nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4); ku::tma_gather4( block_idx >= args.num_orig_kv_blocks ? 
&tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope, plan.bar_raw_ready[rs.buf_idx], plan.u.kv.raw_nope[rs.buf_idx].data() + D_NOPE*row, 0, cur_indices, (int64_t)TMA::CacheHintSm90::EVICT_LAST ); cur_indices = nxt_cur_indices; } plan.bar_raw_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_NOPE*sizeof(e4m3)); plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive(); rs.update(); } }); } else if (warp_idx == 6 && elect_one_sync()) { // KV RoPE retrieval warp run_main_loop([&](const MainLoopArgs &args) { plan.bar_q_utccp.wait(args.bar_phase_batch_rel); plan.bar_last_store_done.wait(args.bar_phase_batch_rel); CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); if constexpr (MODEL_TYPE == ModelType::V32) { plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase^1); } else { plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1); } int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0); int4 nxt_cur_indices; CUTE_UNROLL for (int row = 0; row < B_TOPK; row += 4) { if (row+4 < B_TOPK) nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4); CUTE_UNROLL for (int t = 0; t < D_ROPE/(K_ROPE_SW/2); ++t) { ku::tma_gather4( block_idx >= args.num_orig_kv_blocks ? 
&tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope, plan.bar_rope_ready[rs.buf_idx], plan.u.kv.dequant[rs.buf_idx].rope.data() + (K_ROPE_SW/2)*row + t*B_TOPK*(K_ROPE_SW/2), t*(K_ROPE_SW/2), cur_indices, (int64_t)TMA::CacheHintSm90::EVICT_LAST ); } cur_indices = nxt_cur_indices; } plan.bar_rope_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16)); plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive(); rs.update(); } }); } else if (warp_idx == 7) { // Indices transformation warp // Responsible for generating: TMA coordinates, scale factors, and valid masks static_assert(B_TOPK == 64); static constexpr int tma_coords_step_per_token = MODEL_TYPE == ModelType::V32 ? 656/TMA_K_STRIDE : 576/TMA_K_STRIDE; int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE > 512 int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE; uint8_t* k_scales_ptr = MODEL_TYPE == ModelType::V32 ? (uint8_t*)params.kv + D_NOPE : (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE); uint8_t* extra_k_scales_ptr = MODEL_TYPE == ModelType::V32 ? (uint8_t*)params.extra_kv + D_NOPE : (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE); run_main_loop([&](const MainLoopArgs &args) { int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*s_q_idx; int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*s_q_idx; struct IsOrigBlock {}; struct IsExtraBlock {}; auto process_one_block = [&](int block_idx, auto is_extra_block_t) { static constexpr bool IS_EXTRA_BLOCK = std::is_same_v; int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size; int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? 
params.stride_extra_kv_block : params.stride_kv_block; [[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row; uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr; int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block; int abs_pos, my_indices[2]; if (!IS_EXTRA_BLOCK) { abs_pos = block_idx*B_TOPK + lane_idx*2; *(int2*)my_indices = __ldg((int2*)(indices + abs_pos)); } else { abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2; *(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos)); } plan.bar_valid_coord_scale_free[rs.index_buf_idx].wait(rs.index_bar_phase^1); int tma_coords[2]; e8m0 scales[2*NUM_SCALES_EACH_TOKEN]; char valid_mask = 0; CUTE_UNROLL for (int i = 0; i < 2; ++i) { int block_idx, idx_in_block; block_idx = (unsigned int)my_indices[i] / cur_block_size; idx_in_block = (unsigned int)my_indices[i] % cur_block_size; bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length)); valid_mask |= is_token_valid << i; tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN. if constexpr (MODEL_TYPE == ModelType::V32) { int64_t offset = is_token_valid ? 
block_idx*cur_k_block_stride + idx_in_block*cur_k_row_stride : 0; float4 cur_scale_fp32 = __ldg((float4*)(cur_k_scales_ptr + offset)); e8m0 res[4]; *(__nv_fp8x2_storage_t*)(res+0) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.x, cur_scale_fp32.y}, __NV_NOSAT, cudaRoundZero); *(__nv_fp8x2_storage_t*)(res+2) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.z, cur_scale_fp32.w}, __NV_NOSAT, cudaRoundZero); if (!is_token_valid) *(uint32_t*)res = (uint32_t)0; *(uint32_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = *(uint32_t*)(res); } else { int64_t offset = block_idx*cur_k_block_stride + idx_in_block*8; // Each token has 7 scale factors with an extra 1B padding uint64_t scalesx8 = is_token_valid ? __ldg((uint64_t*)(cur_k_scales_ptr + offset)) : 0; *(uint64_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = scalesx8; } } valid_mask <<= lane_idx%4*2; valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1); valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2); if constexpr (MODEL_TYPE == ModelType::V32) { *(uint64_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(uint64_t*)scales; } else { *(__int128_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(__int128_t*)scales; } *(int2*)(plan.tma_coord[rs.index_buf_idx] + lane_idx*2) = *(int2*)tma_coords; if (lane_idx%4 == 0) plan.is_token_valid[rs.index_buf_idx][lane_idx/4] = valid_mask; plan.bar_valid_coord_scale_ready[rs.index_buf_idx].arrive(); rs.update(); }; CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) { process_one_block(block_idx, IsOrigBlock{}); } CUTE_NO_UNROLL for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) { process_one_block(block_idx, IsExtraBlock{}); } }); } else { run_main_loop([&](const MainLoopArgs &args) {}); } } else { // Dequant warpgroup cutlass::arch::warpgroup_reg_alloc<208>(); // 8 threads per token constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, 
ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = D_NOPE/(GROUP_SIZE*8); int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE; Tensor nope0 = make_tensor(make_smem_ptr(plan.u.kv.dequant[0].nope.data()), SmemLayoutKTiles_SW128{}); bf16* nope0_base = &nope0(group_idx, idx_in_group*8); bf16* nope1_base = nope0_base + (plan.u.kv.dequant[1].nope.data() - plan.u.kv.dequant[0].nope.data()); e4m3* raw_nope0_base = plan.u.kv.raw_nope[rs.buf_idx].data() + group_idx*D_NOPE + idx_in_group*8; e4m3* raw_nope1_base = raw_nope0_base + B_H*D_NOPE; run_main_loop([&](const MainLoopArgs &args) { // plan.bar_last_store_done.wait(args.bar_phase_batch_rel); // No need to wait since the raw nope producer must wait plan.bar_q_utccp.wait(args.bar_phase_batch_rel); CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); plan.bar_raw_ready[rs.buf_idx].wait(rs.bar_phase); plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1); uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(rs.buf_idx == 0 ? nope0_base : nope1_base); e4m3* raw_nope_base = rs.buf_idx == 0 ? raw_nope0_base : raw_nope1_base; auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) { asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n" : : "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16) ); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS }; auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t { return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*D_NOPE + local_col_idx*(GROUP_SIZE*8)); }; // The following code suffers from a 2-way bank conflict when reading from SMEM. 
if constexpr (MODEL_TYPE == ModelType::V32) { CUTE_UNROLL for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) { int row_idx = local_row_idx*NUM_GROUPS + group_idx; bf16 scales[4]; e8m0 scales_e8m0[4]; *(uint32_t*)scales_e8m0 = *(uint32_t*)plan.scales[rs.index_buf_idx][row_idx]; *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0)); *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2)); uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0); CUTE_UNROLL for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) { ku::nve4m3x2 data_fp8[4]; ku::nvbf16x2 data_bf16[4]; *(uint64_t*)data_fp8 = cur_data_fp8x8; if (local_col_idx+1 < COLS_PER_GROUP) cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1); bf16 scale = scales[local_col_idx / (D_NOPE/(GROUP_SIZE*8)/4)]; CUTE_UNROLL for (int i = 0; i < 4; ++i) { data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale)); } st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16); } } } else { CUTE_UNROLL for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) { int row_idx = local_row_idx*NUM_GROUPS + group_idx; bf16 scales[8]; e8m0 scales_e8m0[8]; *(uint64_t*)scales_e8m0 = *(uint64_t*)plan.scales[rs.index_buf_idx][row_idx]; *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0)); *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2)); *(__nv_bfloat162_raw*)(scales+4) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+4)); *(__nv_bfloat162_raw*)(scales+6) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+6)); uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0); CUTE_UNROLL for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) { ku::nve4m3x2 data_fp8[4]; ku::nvbf16x2 data_bf16[4]; *(uint64_t*)data_fp8 = 
cur_data_fp8x8;
                if (local_col_idx+1 < COLS_PER_GROUP)
                    cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);
                bf16 scale = scales[local_col_idx];
                CUTE_UNROLL
                for (int i = 0; i < 4; ++i) {
                    data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));
                }
                st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);
            }
        }
    }
    cutlass::arch::fence_view_async_shared();
    plan.bar_nope_ready[rs.buf_idx].arrive();
    plan.bar_raw_free[rs.buf_idx].arrive();
    plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
    rs.update();
}
});
}
#else
    if (cute::thread0()) {
        CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100 ~ sm119");
    }
#endif
}

// NOTE(review): in this copy of the file, several `template <...>` parameter lists, `<...>` template-argument
// lists and the `<<<...>>>` launch configuration appear to have been stripped (e.g. a bare `template` keyword
// directly in front of a declaration, `mla_kernel<<>>`). Those spans are kept exactly as found because their
// contents cannot be recovered from this view; restore them from upstream before building.

// Kernel entry point: forwards the grid-constant params and TMA descriptors to Kernel's device function.
// __launch_bounds__ caps the block size at Kernel::NUM_THREADS threads.
template __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) {
    Kernel::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(params, tma_params);
}

// Host-side launcher.
//  1. Validates that `params` matches this instantiation: topk / extra_topk are multiples of B_TOPK,
//     h_q == B_H, h_kv == 1, d_qk == D_Q, d_v == D_V, and (MODEL1 only) KV rows are densely packed.
//  2. Builds the TMA descriptors:
//     - Q via cute::make_tma_copy;
//     - the SW64 tail of Q via a raw tensor map when D_Q_SW64 > 0;
//     - O via a TMA store;
//     - NoPE / RoPE gather maps for the main KV cache and, only when extra_topk > 0, for the extra KV cache.
//  3. Raises the kernel's dynamic shared-memory limit to sizeof(SharedMemoryPlan) and launches it.
// Precondition failures are reported through KU_ASSERT; CUDA errors through KU_CUDA_CHECK / KU_CHECK_KERNEL_LAUNCH.
template void KernelTemplate::run(const SparseAttnDecodeParams &params) {
    KU_ASSERT(params.topk % B_TOPK == 0, "topk (%d) mod B_TOPK (%d) must be 0", params.topk, B_TOPK);
    KU_ASSERT(params.extra_topk % B_TOPK == 0, "extra_topk (%d) mod B_TOPK (%d) must be 0", params.extra_topk, B_TOPK);
    KU_ASSERT(params.h_q == B_H);
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.d_qk == D_Q);
    KU_ASSERT(params.d_v == D_V);
    if constexpr (MODEL_TYPE == ModelType::MODEL1) {
        // Per token: D_NOPE fp8 bytes + D_ROPE bf16 values (2 bytes each) + 8 bytes of scale factors.
        // MODEL1 locates a token's scale bytes from the page-block base pointer rather than from the row,
        // so the rows of a page block must be densely packed.
        constexpr int BYTES_PER_TOKEN = D_NOPE + 2*D_ROPE + 8;
        KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1");  // Each block must be contiguous
    }

    // Q viewed as a (B_H, D_Q, s_q, b) bf16 tensor with unit stride along D_Q; the box follows SmemLayoutQ_SW128.
    auto shape_Q_SW128 = make_shape(B_H, D_Q, params.s_q, params.b);
    auto tma_Q_SW128 = cute::make_tma_copy(
        SM90_TMA_LOAD{},
        make_tensor(
            make_gmem_ptr((bf16*)params.q),
            make_layout(
                shape_Q_SW128,
                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b)
            )
        ),
        SmemLayoutQ_SW128{}
    );

    // Output O viewed as a (B_H, D_V, s_q, b) bf16 tensor, written back with TMA stores.
    auto shape_O = make_shape(B_H, D_V, params.s_q, params.b);
    auto tma_O = cute::make_tma_copy(
        SM90_TMA_STORE{},
        make_tensor(
            make_gmem_ptr((bf16*)params.out),
            make_layout(
                shape_O,
                make_stride(params.stride_o_h_q, _1{}, params.stride_o_s_q, params.stride_o_b)
            )
        ),
        SmemLayoutOBuf_TMA{}
    );

    // Q columns [D_Q_SW128, D_Q_SW128 + D_Q_SW64) as a 5-D map:
    // 32-element (64 B) inner chunks with 64B swizzle, B_H heads, D_Q_SW64/32 chunks stepped by 32 elements.
    // Left zero-initialised when there is no SW64 part.
    CUtensorMap tensor_map_q_sw64{};
    if constexpr (D_Q_SW64 > 0) {
        tensor_map_q_sw64 = ku::make_tensor_map(
            {D_Q_SW64, (uint64_t)params.h_q, D_Q_SW64/32, (uint64_t)params.s_q, (uint64_t)params.b},
            ku::make_stride_helper(std::vector{params.stride_q_h_q, (int64_t)32, params.stride_q_s_q, params.stride_q_b}, sizeof(bf16)),
            {32, B_H, D_Q_SW64/32, 1, 1},
            (bf16*)params.q + D_Q_SW128,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
        );
    }

    // Builds the (NoPE, RoPE) gather maps for one KV cache. Rows are TMA_K_STRIDE bytes apart;
    // k_batch_stride (the page-block stride) must be a multiple of TMA_K_STRIDE, so every token row across
    // all `num_blocks` page blocks is addressable as a single row index.
    auto get_nope_rope_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t k_batch_stride) -> std::pair<CUtensorMap, CUtensorMap> {
        static_assert(D_NOPE%8 == 0);
        KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr);
        KU_ASSERT(k_batch_stride % TMA_K_STRIDE == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary", is_extra?"extra_":"", k_batch_stride, TMA_K_STRIDE);
        // NOTE We combine 8 float8 into 1 int64 since boxdim cannot > 256
        CUtensorMap tensor_map_kv_nope = ku::make_tensor_map(
            {D_NOPE/8, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},
            {TMA_K_STRIDE},
            {D_NOPE/8, 1},
            k_ptr,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT64,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
        );
        // The RoPE part starts right after the NoPE bytes.
        // For V32, 16 additional bytes (the per-token fp32 scales) sit in between.
        // The swizzle mode follows K_ROPE_SW.
        CUtensorMap tensor_map_kv_rope = ku::make_tensor_map(
            {D_ROPE, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},
            {TMA_K_STRIDE},
            {K_ROPE_SW/2, 1},
            (uint8_t*)k_ptr + (MODEL_TYPE == ModelType::V32 ? (D_NOPE+16) : D_NOPE),
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            K_ROPE_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
        );
        return {tensor_map_kv_nope, tensor_map_kv_rope};
    };
    auto [tensor_map_kv_nope, tensor_map_kv_rope] = get_nope_rope_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block);
    // The extra-KV maps stay zero-initialised unless an extra top-k list is present.
    CUtensorMap tensor_map_extra_kv_nope{}, tensor_map_extra_kv_rope{};
    if (params.extra_topk > 0) {
        std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_nope_rope_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block);
    }

    TmaParams<
        decltype(shape_Q_SW128), decltype(tma_Q_SW128),
        decltype(shape_O), decltype(tma_O)
    > tma_params = {
        shape_Q_SW128, tma_Q_SW128,
        shape_O, tma_O,
        tensor_map_q_sw64,
        tensor_map_kv_nope, tensor_map_kv_rope,
        tensor_map_extra_kv_nope, tensor_map_extra_kv_rope
    };
    auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel, decltype(tma_params)>;

    // The whole SharedMemoryPlan is allocated as dynamic shared memory, so opt in to the larger limit.
    constexpr size_t smem_size = sizeof(SharedMemoryPlan);
    static_assert(smem_size < 227*1024);
    KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    // NOTE Don't use PDL because of potential compiler bugs!
    mla_kernel<<>>(params, tma_params);
    KU_CHECK_KERNEL_LAUNCH();
}

// Free-function entry point (declared in kernel.h below); forwards to KernelTemplate::run.
template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params) {
    KernelTemplate::run(params);
}

}  // closes the scope opened earlier in this file (presumably its namespace)
================================================ FILE: csrc/sm100/decode/head64/kernel.h ================================================
#pragma once

#include "params.h"

namespace sm100::decode::head64 {

// Launches the sm100 head-64 sparse FP8 split-KV MLA decoding kernel described by `params`.
template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);

}
================================================ FILE: csrc/sm100/helpers.h ================================================
#pragma once

// NOTE(review): the header names of the three system includes below appear to have been stripped in this copy.
#include
#include
#include
#include "defines.h"

namespace sm100 {

using namespace cute;

// Largest of the four components of `t`.
CUTE_DEVICE int int4_max(int4 t) { return max(max(t.x, t.y), max(t.z, t.w)); }

// Smallest of the four components of `t`.
CUTE_DEVICE int int4_min(int4 t) { return min(min(t.x, t.y), min(t.z, t.w)); }

// Convert 2x fp8_e4m3 to 2x bf16 with scaling
// The fp8 pair is widened to float2 and rounded to bf16 (__float22bfloat162_rn).
// Each lane is then multiplied by `scale` in bf16 arithmetic.
CUTE_DEVICE nv_bfloat162 fp8x2_to_bf16x2_with_scale(__nv_fp8x2_e4m3 data, nv_bfloat16 scale) {
    // TODO Use native conversion for CUDA >= 13.1
    float2 data_float2 = (float2)data;
    nv_bfloat162 data_bf16x2 = __float22bfloat162_rn(data_float2);
    return nv_bfloat162 {
        data_bf16x2.x * scale,
        data_bf16x2.y * scale
    };
}

}
================================================ FILE: csrc/sm100/prefill/dense/collective/fmha_common.hpp ================================================
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2.
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/
#pragma once

#include "cutlass/kernel_hardware_info.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::collective {

using namespace cute;

// NOTE(review): in this copy, every `template` keyword below is followed directly by its declaration, so the
// `<...>` template parameter lists (and some `<...>` template argument lists) appear to have been stripped.
// The code is left exactly as found; restore the lists from upstream CUTLASS before building.

// Issues one cute::gemm per k-block (mode 2 of tA / tB) into the accumulator tC.
// The first call uses whatever `atom.accumulate_` mode the caller left on the atom.
// After every call the mode is switched to One, so all later k-blocks accumulate into tC.
// All three operands must be rank 3 (enforced by the static_assert).
template
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
  constexpr int rA = decltype(rank(tA))::value;
  constexpr int rB = decltype(rank(tB))::value;
  constexpr int rC = decltype(rank(tC))::value;
  static_assert(rA == 3 && rB == 3 && rC == 3);

  CUTLASS_PRAGMA_UNROLL
  for (int k_block = 0; k_block < size<2>(tA); k_block++) {
    cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
    atom.accumulate_ = decltype(atom.accumulate_)::One;
  }
}

// Same as gemm_reset_zero_acc, but first forces `atom.accumulate_` to Zero.
// Presumably the first k-block then overwrites tC instead of adding to it; confirm against the
// MMA atom's accumulate semantics.
template
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
  atom.accumulate_ = decltype(atom.accumulate_)::Zero;
  gemm_reset_zero_acc(atom, tA, tB, tC);
}

// Returns composition(layout, prepend(make_layout(stages), _)).
// NOTE(review): judging by the name, this re-views a staged shared-memory layout so the stage count becomes
// an explicit mode; confirm the exact mode ordering against CuTe's `prepend` / `composition`.
template
CUTE_DEVICE constexpr
auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
  return composition(layout, prepend(make_layout(stages), _));
}

// Broadcasts lane 0's value of `a` to every lane via __shfl_sync with the full 0xffffffff mask,
// so all 32 lanes of the warp must call it.
template
CUTE_DEVICE
T warp_uniform(T a) {
  return __shfl_sync(0xffffffff, a, 0);
}

// Two overloads that map one TiledMMA type to another.
// Only the argument's type matters: the parameter is unnamed, and a default-constructed TiledMMA is returned.
// NOTE(review): the name suggests converting an SM100 MMA whose A operand comes from shared memory ("SS")
// into its TMEM-A ("TS") variant; the stripped template argument lists needed to confirm this are missing here.
template
CUTE_HOST_DEVICE constexpr
auto to_tiled_mma_sm100_ts(
    TiledMMA, cute::C, cute::integral_constant, cute::integral_constant, cute::integral_constant, cute::integral_constant>, TAs...>, TMs...>) {
  return TiledMMA>, TAs...>, TMs...>{};
}

template
CUTE_HOST_DEVICE constexpr
auto to_tiled_mma_sm100_ts(
    TiledMMA, TAs...>, TMs...>) {
  return TiledMMA, TAs...>, TMs...>{};
}

// Compile-time choice for the calling warpgroup.
// RegCount below 128 calls warpgroup_reg_dealloc; otherwise it calls warpgroup_reg_alloc.
// NOTE(review): the `<RegCount>` template arguments of both calls also appear stripped in this copy.
template
CUTLASS_DEVICE
void warpgroup_reg_set() {
  if constexpr (RegCount < 128) {
    cutlass::arch::warpgroup_reg_dealloc();
  }
  else {
    cutlass::arch::warpgroup_reg_alloc();
  }
}

} // namespace cutlass::fmha::collective
================================================ FILE: csrc/sm100/prefill/dense/collective/fmha_fusion.hpp ================================================
/***************************************************************************************************
/*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

// Attention-mask policies for the SM100 FMHA kernels, plus helpers for
// variable-length (ragged) sequence shapes.
//
// Conventions visible in this file:
//   - problem_size mode 0 is the Q length and mode 1 the K length.
//   - tile_shape mode 0 is the Q tile size and mode 1 the K tile size.
//   - blk_coord mode 0 is the Q tile index.
//
// Every mask policy exposes the same four members:
//   get_trip_count          - number of K tiles to visit for one Q tile
//   get_masked_trip_count   - how many of those tiles need apply_mask
//   get_unmasked_trip_count - the remaining tiles (trip - masked)
//   apply_mask              - writes -INFINITY into masked QK scores
//
// NOTE(review): the template parameter lists (and several template argument
// lists) in this copy were lost during text extraction; compare with the
// upstream header before relying on the exact signatures.
namespace cutlass::fmha::collective {

using namespace cute;

// No masking: visits every K tile and never modifies the scores.
struct NoMask {
  // Number of K tiles: ceil_div(seqlen_k, tile_k).
  template
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    return ceil_div(get<1>(problem_size), get<1>(tile_shape));
  }

  // No tile needs masking.
  template
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    return 0;
  }

  // Every tile is unmasked. This is a non-virtual call, so it always uses
  // NoMask::get_trip_count.
  template
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    return get_trip_count(blk_coord, tile_shape, problem_size);
  }

  // No-op.
  template
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    return;
  }
};

// Masks only the K columns at or beyond seqlen_k. When seqlen_k is not a
// multiple of the K tile size, exactly one K tile (the partial one) is
// counted as masked.
struct ResidualMask : NoMask {

  using Base = NoMask;

  // 1 if seqlen_k % tile_k != 0, else 0.
  template
  CUTLASS_DEVICE int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
      return 1;
    }
    return 0;
  }

  // Total K tiles minus the partial tile, if there is one.
  template
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    // if seqlen_k is not a multiple of the K tile size, one tile is masked
    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
      return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
    }
    return get_trip_count(blk_coord, tile_shape, problem_size);
  }

  // Sets acc_qk(i) to -INFINITY wherever the K coordinate of index_qk(i) is
  // >= seqlen_k.
  template
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // This is useful if seqlen_k % kBlockN != 0 since it masks
    // the remaining elements out from softmax.
    // d % kHeadDim != 0 or seqlen_q % kBlockM != 0 do not suffer from similar
    // issues as they are transparently taken care of by TMA and the
    // epilogue, if it is instantiated with predication support.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      if (get<1>(pos) >= get<1>(problem_size)) {
        acc_qk(i) = -INFINITY;
      }
    }
  }
};

// Backward-pass variant of ResidualMask. The trip counts are identical, but
// apply_mask bounds-checks both the Q and the K coordinate.
struct ResidualMaskForBackward : NoMask {

  using Base = NoMask;

  // 1 if seqlen_k % tile_k != 0, else 0.
  template
  CUTLASS_DEVICE int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
      return 1;
    }
    return 0;
  }

  // Total K tiles minus the partial tile, if there is one.
  template
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    // if seqlen_k is not a multiple of the K tile size, one tile is masked
    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
      return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
    }
    return get_trip_count(blk_coord, tile_shape, problem_size);
  }

  // Sets acc_qk(i) to -INFINITY unless its (q, k) position is element-wise
  // less than (seqlen_q, seqlen_k). In other words, it masks when either
  // coordinate is out of range.
  template
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // This is useful if seqlen_k % kBlockN != 0 since it masks
    // the remaining elements out from softmax.
    // d % kHeadDim != 0 or seqlen_q % kBlockM != 0 do not suffer from similar
    // issues as they are transparently taken care of by TMA and the
    // epilogue, if it is instantiated with predication support.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      if (! elem_less(pos, select<0,1>(problem_size))) {
        acc_qk(i) = -INFINITY;
      }
    }
  }
};

// There are two ways to do causal if N_Q != N_K
// (1) The Q is at the beginning of the matrix
// (2) The Q is at the end of the matrix
// IsQBegin selects (1); otherwise Q rows are shifted by
// offset_q = seqlen_k - seqlen_q.
template
struct CausalMask : NoMask {

  using Base = NoMask;

  static constexpr bool IsQBegin = kIsQBegin;

  // Number of K tiles that Q tile blk_coord[0] can attend to.
  // The tile's rows end at (blk + 1) * tile_q (plus offset_q when Q is at the
  // end), so that many keys, rounded up to whole K tiles, are needed.
  // The result is capped at the full K tile count.
  // NOTE(review): if seqlen_q > seqlen_k and !IsQBegin, offset_q is negative
  // and this can return 0 or less for leading Q tiles. The forward
  // mainloop's mma issues its first K tile unconditionally, so confirm that
  // case cannot reach it.
  template
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    // See note below on different ways to think about causal attention
    // Again, we'd add the offset_q into the max_blocks_q calculation
    int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
    if constexpr (IsQBegin) {
      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
      return std::min(max_blocks_k, max_blocks_q);
    } else {
      const int offset_q = get<1>(problem_size) - get<0>(problem_size);
      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
      return std::min(max_blocks_k, max_blocks_q);
    }
  }

  // Number of K tiles the causal diagonal can intersect, capped at
  // get_trip_count.
  // - Q at the beginning: ceil(tile_q / tile_k).
  // - Q at the end: one extra tile (corner_count) is added when seqlen_k or
  //   seqlen_q is not a multiple of its tile size. Presumably this is because
  //   the shifted diagonal can then straddle one more K tile.
  template
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
    if constexpr (IsQBegin) {
      return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
    } else {
      const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
      return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
    }
  }

  // Tiles that need no masking: trip count minus masked trip count.
  template
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
  }

  // Sets acc_qk(i) to -INFINITY when the key is in the query's future
  // (q < k, or q + offset_q < k when Q is at the end) or the key is at or
  // beyond seqlen_k.
  template
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // There are two ways to do causal if N_Q != N_K
    // (1) is to assume that the Q is at the beginning of the matrix
    //    - this is the default setting.
    // (2) is that it is at the end of the matrix
    //    - this is usually what we want for inference settings
    //      where we only compute the next row and use cache for the rest
    //    - if you'd like this, you only need to set kIsQBegin=false

    if constexpr (IsQBegin) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(acc_qk); i++) {
        auto pos = index_qk(i);
        if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
          acc_qk(i) = -INFINITY;
        }
      }
    } else {
      const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(acc_qk); i++) {
        auto pos = index_qk(i);
        if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
          acc_qk(i) = -INFINITY;
        }
      }
    }
  }
};

// Causal mask for the backward pass. apply_mask combines the causal test
// with a bounds check against problem_size via elem_less.
// NOTE(review): get_trip_count / get_masked_trip_count /
// get_unmasked_trip_count are declared in both bases (in
// ResidualMaskForBackward's case get_trip_count comes via NoMask), so
// unqualified calls to them on this type would presumably be ambiguous.
// `Base` names CausalMask for qualified access; confirm callers use it.
template
struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {

  using Base = CausalMask;

  template
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // There are two ways to do causal if N_Q != N_K
    // (1) is to assume that the Q is at the beginning of the matrix
    //    - this is the kIsQBegin == true case (offset_q stays 0)
    // (2) is that it is at the end of the matrix
    //    - this is usually what we want for inference settings
    //      where we only compute the next row and use cache for the rest
    //    - when kIsQBegin == false the offset below is applied, i.e. a
    //      position is masked if get<0>(pos) + offset_q < get<1>(pos)
    int offset_q = 0;
    if constexpr (!kIsQBegin) {
      offset_q = get<1>(problem_size) - get<0>(problem_size);
    }
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
      if (masked) {
        acc_qk(i) = -INFINITY;
      }
    }
  }

};

// Ragged-sequence length descriptor.
// - It converts implicitly to max_length, so it can stand in for a plain int
//   extent.
// - When cumulative_length is non-null, the helpers below read
//   cumulative_length[idx] and cumulative_length[idx+1] as the start and end
//   of batch entry idx.
struct VariableLength {
  int max_length;
  int* cumulative_length = nullptr;
  int total_length = -1;
  CUTE_HOST_DEVICE operator int() const { return max_length; }
};

// Trait that is true only for the VariableLength specialization.
// NOTE(review): the template arguments here (presumably VariableLength, and
// a cvref-stripped T in is_variable_length_v) are missing in this copy.
template struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl : std::true_type {};
template constexpr bool is_variable_length_v = is_variable_length_impl>::value;

// Replaces every VariableLength leaf of `shape` with the length of batch
// entry idx, cumulative_length[idx+1] - cumulative_length[idx]. Other leaves
// are returned unchanged.
template
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
  return transform_leaf(shape, [&](auto const& s) {
    if constexpr (is_variable_length_v) {
      return s.cumulative_length[idx+1] - s.cumulative_length[idx];
    }
    else {
      return s;
    }
  });
}

// Same shape transform as above, plus a coordinate transform. For each
// VariableLength leaf, the matching coordinate c becomes the pair
// (c, cumulative_length[idx]); other coordinates are unchanged.
// Returns (new_shape, new_coord).
template
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
  auto new_shape = apply_variable_length(shape, idx);
  auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
    if constexpr (is_variable_length_v) {
      return cute::make_tuple(c, s.cumulative_length[idx]);
    }
    else {
      return c;
    }
  });
  return cute::make_tuple(new_shape, new_coord);
}

// Takes the batch index from the last leaf of coord's last mode
// (back(back(coord))). Returns (result_shape, result_offset):
// - VariableLength leaves become that entry's length and its start offset
//   cumulative_length[idx].
// - All other shape leaves are unchanged, and their offsets are _0.
template
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
  auto idx = back(back(coord));
  auto result_shape = transform_leaf(shape, [&](auto const& s) {
    if constexpr (is_variable_length_v) {
      return s.cumulative_length[idx+1] - s.cumulative_length[idx];
    }
    else {
      return s;
    }
  });
  auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
    if constexpr (is_variable_length_v) {
      return s.cumulative_length[idx];
    }
    else {
      return _0{};
    }
  });
  return cute::make_tuple(result_shape, result_offset);
}

} // namespace cutlass::fmha::collective

namespace cute {

// Lets CuTe treat the type as an integral (the argument, presumably
// VariableLength, is missing in this copy).
template<>
struct is_integral : true_type {};

// Debug printing: prints max_length and the cumulative_length pointer.
CUTE_HOST_DEVICE
void print(cutlass::fmha::collective::VariableLength a) {
  printf("Varlen<%d, %p>", a.max_length, a.cumulative_length);
}

}

================================================
FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
================================================
/*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

// NOTE(review): several template parameter / argument lists in this copy
// were stripped during text extraction; compare with the upstream header.
namespace cutlass::fmha::collective {

// Epilogue of the SM100 FMHA forward kernel.
// Each call to store() waits for two O tiles handed over through a two-stage
// pipeline (from the correction stage, per the comments in store()) and
// writes them from shared memory to global memory with TMA stores.
// ptr_LSE / dLSE are carried through Arguments and Params but are not read
// by store() in this file.
template<
  class Element,
  class ElementAcc,
  class TileShape,  // Q, D, _
  class StrideO,  // Q, D, B
  class StrideLSE_,  // Q, B
  class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

  // One stage per O tile (O0 / O1).
  using Pipeline = cutlass::PipelineAsync<2>;

//  using SmemLayoutO = decltype(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));
  using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
      cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>());
//  using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
  // Two (Q x D) O tiles: TileShape's mode 2 is replaced by _2, one slot per tile.
  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
  using SmemLayoutO_ = SmemLayoutO;
  using StrideLSE = StrideLSE_;
  using ElementOut = Element;

  static const int NumWarpsEpilogue = 1;
  static const int NumWarpsLoad = 1;

  struct TensorStorage {

    using SmemLayoutO = SmemLayoutO_;
    cute::array_aligned> smem_o;

  };

  struct Arguments {
    Element* ptr_O;
    StrideO dO;

    ElementAcc* ptr_LSE;
    StrideLSE dLSE;
  };

  // Type of the TMA store for a single O tile (slot 0 of SmemLayoutO).
  using TMA_O = decltype(make_tma_copy(
    SM90_TMA_STORE{},
    make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}),
    SmemLayoutO{}(_,_,_0{})
  ));


  struct Params {
    TMA_O tma_store_o;

    ElementAcc* ptr_LSE;
    StrideLSE dLSE;
  };

  // FMHA and MLA have different input ProblemShapes;
  // get problem_shape_O according to the input ProblemShape.
  // - If the tested sub-mode of ProblemShape has rank 2, the result is
  //   select<0,2,3>(problem_shape) with mode 1 replaced by
  //   get<2, 0>(problem_shape). The rank_v expression is truncated in this
  //   copy; it is presumably mode 2.
  // - Otherwise the result is select<0,2,3>(problem_shape).
  // NOTE(review): this is CUTLASS_DEVICE but is also called from the host-side
  // to_underlying_arguments. That presumably relies on it being constexpr
  // (e.g. nvcc --expt-relaxed-constexpr); confirm the build flags.
  template
  CUTLASS_DEVICE static constexpr
  auto get_problem_shape_O (
    ProblemShape const& problem_shape) {
    if constexpr (rank_v(ProblemShape{}))> == 2) {
      return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));
    } else {
      return select<0,2,3>(problem_shape);
    }
  }

  // Builds the TMA store descriptor for O.
  // For variable-length Q (cumulative_length non-null), the whole ragged batch
  // is described as one tensor:
  // - mode 0 extent becomes max(1, max_length_q);
  // - the batch stride (mode 2,1) becomes the row stride;
  // - the batch extent becomes max(1, max_length_q * (1 + batch extent));
  // - ptr_O is moved back by max_length_q rows, and store() adds that back
  //   per tile.
  // For a fixed-length Q, mode 0 is only clamped to at least 1.
  // `workspace` is not used.
  template
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape,
      Arguments const& args,
      void* workspace = nullptr) {

    auto ptr_O = args.ptr_O;
    StrideO dO = args.dO;

    auto problem_shape_O = get_problem_shape_O(problem_shape);

    if constexpr (is_variable_length_v>) {
      auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
      if (cumulative_length_q != nullptr) {
          int max_length_q = get<0>(problem_shape).max_length;
          get<0>(problem_shape_O).max_length = max(1, max_length_q);
          // for variable sequence length, the batch is in units of row_stride
          get<2,1>(dO) = get<0>(dO);
          get<2,1>(problem_shape_O) = max(1, max_length_q * (1 + get<2,1>(problem_shape_O)));
          // offset ptr by the amount we add back in later (in store())
          ptr_O -= max_length_q * get<0>(dO);
      }
    } else {
      get<0>(problem_shape_O) = max(1, get<0>(problem_shape_O));
    }

    auto tma_store_o = make_tma_copy(
      SM90_TMA_STORE{},
      make_tensor(ptr_O, problem_shape_O, dO),
      SmemLayoutO{}(_,_,_0{})
    );

    return {
      tma_store_o,
      args.ptr_LSE,
      args.dLSE
    };
  }

  // Prefetches the O TMA descriptor.
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& params) {
    cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
  }

  // NOTE(review): store() takes its own `params` argument, which shadows this
  // member inside store().
  const Params& params;

  CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}

  // Stores the two O tiles of this CTA's Q block. O tile j (j = 0, 1) goes
  // to Q-tile index 2 * blk_coord[0] + j. For each tile, the steps are:
  // 1. wait on the pipeline;
  // 2. one elected lane issues the TMA store from smem slot j;
  // 3. tma_store_arrive commits it.
  // The first pipeline stage is released after tma_store_wait<1>, and the
  // second after tma_store_wait<0>.
  // In the OrderLoadEpilogue configuration (the is_same_v condition is
  // truncated in this copy), the named EpilogueBarrier is arrived on with
  // (NumWarpsLoad + NumWarpsEpilogue) warps' worth of threads before the
  // final release.
  template
  CUTLASS_DEVICE auto
  store(
      BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
      Params const& params, ParamsProblemShape const& params_problem_shape,
      TensorStorage& shared_storage,
      Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {

    BlkCoord blk_coord = blk_coord_in;
    uint32_t lane_predicate = cute::elect_one_sync();

    using X = Underscore;

    int o0_index = 2 * get<0>(blk_coord);
    int o1_index = 2 * get<0>(blk_coord) + 1;

    Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
    // offset mode 0 by (max_length - real_length)
    // offset mode 2,1 by cumulative_length + real_length
    // the ptr is already offset by - max_length
    // so in total this achieves: Q row q of batch entry b is written to row
    // cumulative_length[b] + q of O. Rows with q >= real_length have a mode-0
    // coordinate >= max_length, i.e. outside the tensor extent (presumably so
    // the TMA store drops them).
    int offs_0 = 0;
    int offs_2_1 = 0;

    if constexpr (is_variable_length_v>) {
      auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
      if (cumulative_length_q != nullptr) {
        int max_length_q = get<0>(params_problem_shape).max_length;
        offs_0 = max_length_q - get<0>(problem_shape);
        offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape);
        // The batch position is folded into offs_2_1, so index batch 0 below.
        get<2,1>(blk_coord) = 0;
      }
    }

    Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p);

    Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{});
    Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord));
    Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
    auto block_tma = params.tma_store_o.get_slice(0);
    Tensor tOsO = block_tma.partition_S(sO);
    Tensor tOgO = block_tma.partition_D(gO);

    auto pipeline_release_state = pipeline_consumer_state;

    // O1 O2
    // one pipeline: O
    // wait from corr, issue tma store on smem
    pipeline.consumer_wait(pipeline_consumer_state);
    ++pipeline_consumer_state;

    if (lane_predicate) {
      copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index));
    }
    tma_store_arrive();

    pipeline.consumer_wait(pipeline_consumer_state);
    ++pipeline_consumer_state;

    if (lane_predicate) {
      copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index));
    }
    tma_store_arrive();

    // Presumably waits until at most one store is still reading smem, so the
    // first O slot can be handed back.
    tma_store_wait<1>();

    pipeline.consumer_release(pipeline_release_state);
    ++pipeline_release_state;

    tma_store_wait<0>();

    if constexpr (cute::is_same_v) {
      cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
                                          cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
    }

    pipeline.consumer_release(pipeline_release_state);
    ++pipeline_release_state;

  }

};

} // namespace cutlass::fmha::collective

================================================
FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
================================================
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cute/arch/simd_sm100.hpp" #include "cute/tensor.hpp" #include "cute/layout.hpp" #include "../collective/fmha_common.hpp" #include "../collective/fmha_fusion.hpp" #include "../collective/sm100_fmha_load_tma_warpspecialized.hpp" namespace cutlass::fmha::collective { using namespace cute; template< class Element_, class ElementQK_, class ElementPV_, class TileShape_, class StrideQ_, class StrideK_, class StrideV_, class Mask_, // shape here is QG K H // and referes to the two softmax warps // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) // (1, 2, 1) means they sit side by side (best for small Q / large K) class ThreadShape = Shape<_2, _1, _1>, // Since shared memory is sufficient for FMHA, there is no need to reuse shared memory. class OrderLoadEpilogue = cute::false_type > struct Sm100FmhaFwdMainloopTmaWarpspecialized { using Element = Element_; using ElementQK = ElementQK_; using ElementPV = ElementPV_; using TileShape = TileShape_; using StrideQ = StrideQ_; using StrideK = StrideK_; using StrideV = StrideV_; using Mask = Mask_; static constexpr int StageCountQ = 2; static constexpr int StageCountKV = sizeof(Element_) == 1 ? 
4 : 3; using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; using ClusterShape = Shape<_1, _1, _1>; static const int Alignment = 128 / sizeof_bits_v; using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, StrideQ, Alignment, Element, StrideK, Alignment, ElementQK, TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // the stride for A does not matter since we do not load from smem at all Element, StrideK, Alignment, Element, decltype(select<1,0,2>(StrideV{})), Alignment, ElementPV, TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); // Reuse shared memory for V and O. 
static constexpr bool IsOrderLoadEpilogue = std::is_same_v; struct TensorStorage { cute::array_aligned> smem_q; union { cute::array_aligned> smem_k; cute::array_aligned> smem_v; }; }; enum class TmemAllocation : uint32_t { kSizeS = 128, kSizeO = 128, kSizeP = 32, S0 = 0, S1 = S0 + kSizeS, V0 = S0, // stats storage from softmax to correction V1 = S1, P0 = S0 + kSizeP, P1 = S1 + kSizeP, O0 = S1 + kSizeS, O1 = O0 + kSizeO, kEnd = O1 + kSizeO }; // indices for V0 / V1 enum : int { kIdxOldRowMax = 0, kIdxNewRowMax = 1, kIdxFinalRowSum = 0, kIdxFinalRowMax = 1 }; // from load to mma warp, protects q in smem using PipelineQ = cutlass::PipelineTmaUmmaAsync< StageCountQ, typename CollectiveMmaQK::AtomThrShapeMNK >; // from load to mma warp, protects k/v in smem using PipelineKV = cutlass::PipelineTmaUmmaAsync< StageCountKV, typename CollectiveMmaQK::AtomThrShapeMNK >; // from mma to softmax0/1 warp, protects S in tmem // (not sure yet about the reverse direction) // there is one pipe per softmax warp, and the mma warp alternates between them using PipelineS = cutlass::PipelineUmmaAsync<1>; // from softmax0/1/ to correction wg using PipelineC = cutlass::PipelineAsync<1>; // from mma to correction using PipelineO = cutlass::PipelineUmmaAsync<2>; // from corr to epilogue using PipelineE = cutlass::PipelineAsync<2>; using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< /*stages*/ 1, /*groups*/ 2>; static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size"); using Load = Sm100FmhaLoadTmaWarpspecialized< Element, StrideQ, StrideK, StrideV, CollectiveMmaQK, 
CollectiveMmaPV, SmemLayoutQ, SmemLayoutK, SmemLayoutV, TensorStorage, PipelineQ, PipelineKV, Mask, TileShape >; struct Arguments { typename Load::Arguments load; // if zero, defaults to 1/sqrt(D) float scale_softmax = 0.0f; // scaling factors to dequantize QKV float scale_q = 1.0f; float scale_k = 1.0f; float scale_v = 1.0f; // scaling factor to quantize O float inv_scale_o = 1.0f; }; struct Params { typename Load::Params load; float scale_softmax; float scale_softmax_log2; float scale_output; }; template static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { return true; } template static Params to_underlying_arguments( ProblemShape const& problem_shape, Arguments const& args, void* workspace) { float scale_softmax = args.scale_softmax; if (scale_softmax == 0.0f) { scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); } float log2_e = static_cast(std::log2(std::exp(1.0))); return Params{ Load::to_underlying_arguments(problem_shape, args.load, workspace), args.scale_q * args.scale_k * scale_softmax, args.scale_q * args.scale_k * log2_e * scale_softmax, args.scale_v * args.inv_scale_o }; } CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) { Load::prefetch_tma_descriptors(params.load); } template CUTLASS_DEVICE void load( BlkCoord const& blk_coord, ProblemShape const& problem_shape, Params const& params, ParamsProblemShape const& params_problem_shape, TensorStorage& storage, PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { Load load; load.load(blk_coord, problem_shape, params.load, params_problem_shape, storage, pipeline_q, pipeline_q_producer_state, pipeline_kv, pipeline_kv_producer_state); } template CUTLASS_DEVICE auto mma( BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, TensorStorage& storage, PipelineQ& pipeline_q, typename 
PipelineQ::PipelineState& pipeline_q_consumer_state, PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { auto pipeline_q_release_state = pipeline_q_consumer_state; auto pipeline_kv_release_state = pipeline_kv_consumer_state; int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); typename CollectiveMmaQK::TiledMma mma_qk; ThrMMA thr_mma_qk = mma_qk.get_slice(0); typename CollectiveMmaPV::TiledMma mma_pv; TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); Tensor tSrK = thr_mma_qk.make_fragment_B(sK); Tensor tOrV = thr_mma_pv.make_fragment_B(sV); // tmem layout is // S0 S1`O0 O1 // sequential in memory, where S overlaps with P and V Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); Tensor tStS0 = tStS; tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); Tensor tStS1 = tStS; tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); Tensor tOtO0 = tOtO; tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); Tensor tOtO1 = tOtO; tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging Tensor tOrP0 = tOrP; 
tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); Tensor tOrP1 = tOrP; tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); int k_index = 0; int v_index = 0; int q_index = 0; // wait for Q1 q_index = pipeline_q_consumer_state.index(); pipeline_q.consumer_wait(pipeline_q_consumer_state); ++pipeline_q_consumer_state; Tensor tSrQ0 = tSrQ(_,_,_,q_index); // wait for K1 k_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // gemm Q1 * K1 -> S1 pipeline_s0.producer_acquire(pipeline_s0_producer_state); gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; // release K1 if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } // wait for Q2 if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { q_index = pipeline_q_consumer_state.index(); pipeline_q.consumer_wait(pipeline_q_consumer_state); ++pipeline_q_consumer_state; } Tensor tSrQ1 = tSrQ(_,_,_,q_index); if constexpr (get<1>(ThreadShape{}) > 1) { k_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } pipeline_s1.producer_acquire(pipeline_s1_producer_state); // gemm Q2 * K1 -> S2 gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; // release K1 pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; // wait for V1 v_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // this acquire returns the ownership of all of S0 to the mma warp // including the P0 part // acquire corr first to take it out of the critical // path since softmax takes longer 
pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s0.producer_acquire(pipeline_s0_producer_state); // gemm P1 * V1 -> O1 gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; // loop: mask_tile_count -= 1; for (; mask_tile_count > 0; mask_tile_count -= 1) { // wait for Ki k_index = (pipeline_kv_consumer_state.index()); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // gemm Q1 * Ki -> S1 gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } // gemm P2 * V(i-1) -> O2 if constexpr (get<1>(ThreadShape{}) > 1) { v_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s1.producer_acquire(pipeline_s1_producer_state); gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; // release V(i-1) pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; if constexpr (get<1>(ThreadShape{}) > 1) { k_index = (pipeline_kv_consumer_state.index()); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } // gemm Q2 * Ki -> S2 gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; // release Ki 
pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; // wait for Vi v_index = (pipeline_kv_consumer_state.index()); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // gemm P1 * Vi -> O1 pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s0.producer_acquire(pipeline_s0_producer_state); gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } } // release Q1 pipeline_q.consumer_release(pipeline_q_release_state); ++pipeline_q_release_state; // release Q2 if constexpr (get<0>(ThreadShape{}) > 1) { pipeline_q.consumer_release(pipeline_q_release_state); ++pipeline_q_release_state; } // wait for Vi if constexpr (get<1>(ThreadShape{}) > 1) { v_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } // gemm P2 * Vi -> O2 pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s1.producer_acquire(pipeline_s1_producer_state); gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; // release Vi pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... 
} template CUTLASS_DEVICE auto softmax_step( float& row_max, float& row_sum, Stage stage, bool final_call, BlkCoord const& blk_coord, CoordTensor const& cS, Params const& params, ProblemShape const& problem_shape, PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1)); Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); // Each thread owns a single row using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); // wait on tensor core pipe pipeline_s.consumer_wait(pipeline_s_consumer_state); // read all of S from tmem into reg mem Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); if constexpr (need_apply_mask) { Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); } ElementQK old_row_max = row_max; { // compute rowmax float row_max_0 = row_max; float row_max_1 = row_max; float row_max_2 = row_max; float row_max_3 = row_max; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); row_max_3 = 
::fmax(row_max_3, tTMEM_LOADrS(i+3)); } row_max = ::fmax(row_max_0, row_max_1); row_max = ::fmax(row_max, row_max_2); row_max = ::fmax(row_max, row_max_3); } ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); pipeline_c.producer_commit(pipeline_c_producer_state); ++pipeline_c_producer_state; // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) ElementQK scale = params.scale_softmax_log2; ElementQK row_max_scale = row_max_safe * scale; float2 scale_fp32x2 = make_float2(scale, scale); float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); constexpr int kConversionsPerStep = 2; Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); NumericArrayConverter convert; const int kReleasePipeCount = 10; // must be multiple of 2 order_s.wait(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { float2 in = make_float2( tTMEM_LOADrS(i + 0), tTMEM_LOADrS(i + 1) ); float2 out; cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); tTMEM_LOADrS(i + 0) = out.x; tTMEM_LOADrS(i + 1) = out.y; tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); Array in_conv; CUTLASS_PRAGMA_UNROLL for (int j = 0; j < kConversionsPerStep; j++) { in_conv[j] = tTMEM_LOADrS(i + j); } tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { order_s.arrive(); } // this prevents register spills in fp16 if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { if (i == size(tTMEM_LOADrS) - 6) { copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); } } } // tmem_store(reg_S8) -> op_P 
CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); cutlass::arch::fence_view_async_tmem_store(); // notify tensor core warp that P is ready pipeline_s.consumer_release(pipeline_s_consumer_state); ++pipeline_s_consumer_state; pipeline_c.producer_acquire(pipeline_c_producer_state); ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); row_sum *= acc_scale; // row_sum = sum(reg_S) float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); float2 local_row_sum_1 = make_float2(0, 0); float2 local_row_sum_2 = make_float2(0, 0); float2 local_row_sum_3 = make_float2(0, 0); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { // row_sum += tTMEM_LOADrS(i); float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); cute::add(local_row_sum_1, local_row_sum_1, in); in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); cute::add(local_row_sum_2, local_row_sum_2, in); in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); cute::add(local_row_sum_3, local_row_sum_3, in); } cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; row_sum = local_row_sum; if (final_call) { // re-acquire the S part in the final step pipeline_s.consumer_wait(pipeline_s_consumer_state); Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); } } template CUTLASS_DEVICE auto softmax( Stage 
stage, BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); ElementQK row_max = -INFINITY; ElementQK row_sum = 0; Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); auto logical_offset = make_coord( get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) ); Tensor cS = domain_offset(logical_offset, cS_base); pipeline_c.producer_acquire(pipeline_c_producer_state); CUTLASS_PRAGMA_NO_UNROLL for (; mask_tile_count > 0; mask_tile_count -= 1) { softmax_step( row_max, row_sum, stage, (mask_tile_count == 1) && (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), blk_coord, cS, params, problem_shape, pipeline_s, pipeline_s_consumer_state, pipeline_c, pipeline_c_producer_state, order_s ); cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); } // Masked iterations mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); CUTLASS_PRAGMA_NO_UNROLL for (; mask_tile_count > 0; mask_tile_count -= 1) { softmax_step( row_max, row_sum, stage, mask_tile_count == 1, blk_coord, cS, params, problem_shape, pipeline_s, pipeline_s_consumer_state, pipeline_c, pipeline_c_producer_state, order_s ); cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); } pipeline_c.producer_commit(pipeline_c_producer_state); ++pipeline_c_producer_state; pipeline_c.producer_acquire(pipeline_c_producer_state); // empty step to sync against pipe s pipeline_s.consumer_release(pipeline_s_consumer_state); ++pipeline_s_consumer_state; } template CUTLASS_DEVICE auto 
correction_epilogue( float scale, Stage stage, TensorO const& sO_01) { using ElementOut = typename TensorO::value_type; int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); Tensor sO = sO_01(_,_,stage); // As opposed to the softmax, we do not have enough registers here // to load all of the values (for tile kv = 128), so we loop // good values would be either 32 or 64 const int kCorrectionTileSize = 32 / sizeof(ElementOut); using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOsO = mma.get_slice(0).partition_C(sO); Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); if constexpr (decltype(stage == _0{})::value) { tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); } else { static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); } auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); float2 scale_f32x2 = make_float2(scale, scale); // loop: // TMEM_LOAD, FMUL2 scale, TMEM_STORE CUTLASS_PRAGMA_UNROLL for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); Tensor tTMrO = 
make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); #ifndef ONLY_SOFTMAX CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO); j += 2) { float2 in = make_float2(tTMrO(j), tTMrO(j+1)); float2 out; cute::mul(out, scale_f32x2, in); tTMrO(j) = out.x; tTMrO(j+1) = out.y; } #endif constexpr int N = 4 / sizeof(ElementOut); NumericArrayConverter convert; Tensor tSMrO = make_tensor_like(tTMrO); Tensor tCs = recast(tTMrO); Tensor tCd = recast(tSMrO); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tCs); j++) { tCd(j) = convert.convert(tCs(j)); } Tensor tSMsO_i = recast(tTMEM_LOADsO_i); Tensor tSMrO_i = recast(tSMrO); copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); } cutlass::arch::fence_view_async_shared(); } CUTLASS_DEVICE auto correction_rescale( float scale, uint32_t tmem_O) { int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); // As opposed to the softmax, we do not have enough registers here // to load all of the values (for tile kv = 128), so we loop // good values would be either 32 or 64 const int kCorrectionTileSize = 16; using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); Tensor tTMEM_LOADtO = 
thr_tmem_load.partition_S(tOtO_i); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); float2 scale_f32x2 = make_float2(scale, scale); Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); auto copy_in = [&](int i) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); }; auto copy_out = [&](int i) { Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); }; // sequence: LLMSLMSLMSS // loop: // TMEM_LOAD, FMUL2 scale, TMEM_STORE copy_in(0); int count = get<2>(TileShape{}) / kCorrectionTileSize; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < count; i++) { if (i != count - 1) { copy_in(i+1); } Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO_i); j += 2) { float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); float2 out; cute::mul(out, scale_f32x2, in); tTMrO_i(j) = out.x; tTMrO_i(j+1) = out.y; } copy_out(i); } } template< class BlkCoord, class ProblemShape, class ParamsProblemShape, class TensorStorageEpi, class CollectiveEpilogue > CUTLASS_DEVICE auto correction( BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, ParamsProblemShape const& params_problem_shape, TensorStorageEpi& shared_storage_epi, PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, 
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, CollectiveEpilogue& epilogue) { int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); // ignore first signal from softmax as no correction is required pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); // handle the last iteration differently (i.e. 
tmem_load/stsm for epi) mask_tile_count -= 1; CUTLASS_PRAGMA_NO_UNROLL for (; mask_tile_count > 0; mask_tile_count -= 1) { pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); // read row_wise new global max copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); // e^(scale * (old_max - new_max) float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); correction_rescale(scale, uint32_t(TmemAllocation::O0)); pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; cutlass::arch::fence_view_async_tmem_store(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); correction_rescale(scale, uint32_t(TmemAllocation::O1)); pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; cutlass::arch::fence_view_async_tmem_store(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; } pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; // do the final correction to O1 // better to somehow special-case it in the loop above // doesn't matter for non-persistent code, but if it were // persistent we do not want to release O too early pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); // read from V0 // read row_sum and final row_max here Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; 
pipeline_o.consumer_wait(pipeline_o_consumer_state); pipeline_epi.producer_acquire(pipeline_epi_producer_state); // store to epi smem // loop: // TMEM_LOAD // FMUL2 scale = 1 / global_sum * out_quant_scale // F2FP // store to smem Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); if (epilogue.params.ptr_LSE != nullptr) { int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); int row_offset = 0; if constexpr (is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); // load from V1 copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; pipeline_o.consumer_wait(pipeline_o_consumer_state); pipeline_epi.producer_acquire(pipeline_epi_producer_state); correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); if (epilogue.params.ptr_LSE != nullptr) { int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); int row_offset = 0; if constexpr 
(is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; } template< class BlkCoord, class ProblemShape, class ParamsProblemShape, class TensorStorageEpi, class CollectiveEpilogue > CUTLASS_DEVICE auto correction_empty( BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, ParamsProblemShape const& params_problem_shape, TensorStorageEpi& shared_storage_epi, PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, CollectiveEpilogue& epilogue) { pipeline_epi.producer_acquire(pipeline_epi_producer_state); Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); float lse = -INFINITY; int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); #if 1 using ElementOut = typename CollectiveEpilogue::ElementOut; auto tiled_copy = make_cotiled_copy( Copy_Atom, ElementOut>{}, make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), sO.layout()); auto thr_copy = tiled_copy.get_slice(thread_idx); auto tOgO = thr_copy.partition_D(sO); auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); clear(tOrO); copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); #endif if (epilogue.params.ptr_LSE != nullptr) { int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); int row_offset = 0; if constexpr (is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = 
lse; } } pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); cutlass::arch::fence_view_async_shared(); pipeline_epi.producer_acquire(pipeline_epi_producer_state); if (epilogue.params.ptr_LSE != nullptr) { int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); int row_offset = 0; if constexpr (is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } cutlass::arch::fence_view_async_shared(); pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; } }; } // namespace cutlass::fmha::collective ================================================ FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cute/tensor.hpp" #include "cute/layout.hpp" #include "../collective/fmha_common.hpp" #include "../collective/fmha_fusion.hpp" namespace cutlass::fmha::collective { using namespace cute; template< class Element, class StrideQ, class StrideK, class StrideV, class CollectiveMmaQK, class CollectiveMmaPV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV, class TensorStorage, class PipelineQ, class PipelineKV, class Mask, class TileShape > struct Sm100FmhaLoadTmaWarpspecialized { using TileShapeQK = typename CollectiveMmaQK::TileShape; using TileShapePV = typename CollectiveMmaPV::TileShape; struct Arguments { const Element* ptr_Q; StrideQ dQ; const Element* ptr_K; StrideK dK; const Element* ptr_V; StrideV dV; }; using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; using TMA_K = typename CollectiveMmaQK::Params::TMA_B; using TMA_V = typename CollectiveMmaPV::Params::TMA_B; struct Params { TMA_Q tma_load_q; TMA_K tma_load_k; TMA_V tma_load_v; }; template 
static Params to_underlying_arguments( ProblemShape const& problem_shape, Arguments const& args, void* workspace) { auto ptr_Q = args.ptr_Q; auto ptr_K = args.ptr_K; auto ptr_V = args.ptr_V; auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; using IntProblemShape = cute::tuple, int>>; IntProblemShape problem_shape_qk; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; auto cumulative_length_k = get<1>(problem_shape).cumulative_length; if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; get<2>(problem_shape_qk) = get<2>(problem_shape); get<3>(problem_shape_qk) = get<3>(problem_shape); } } else { problem_shape_qk = problem_shape; } get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk)); get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk)); auto params_qk = CollectiveMmaQK::to_underlying_arguments( problem_shape_qk, typename CollectiveMmaQK::Arguments { ptr_Q, dQ, ptr_K, dK, }, /*workspace=*/ nullptr); auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); auto params_pv = CollectiveMmaPV::to_underlying_arguments( problem_shape_pv, typename CollectiveMmaPV::Arguments { ptr_K, dK, // never used, dummy ptr_V, select<1,0,2>(dV), }, /*workspace=*/ nullptr); return Params{ params_qk.tma_load_a, params_qk.tma_load_b, params_pv.tma_load_b }; } CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) { cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); } template CUTLASS_DEVICE void load( BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, Params const& params, ParamsProblemShape const& params_problem_shape, TensorStorage& storage, PipelineQ& pipeline_q, typename 
PipelineQ::PipelineState& pipeline_q_producer_state, PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { BlkCoord blk_coord_q = blk_coord_in; BlkCoord blk_coord_kv = blk_coord_in; int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); using X = Underscore; // this one is only executed by one thread, no need to elect_one // Q1, K1, Q2, V1, K2, V2, K3, V3, ... // two pipes: Q and KV // from Memory (prod) to TensorCore (cons) // compute gQ, sQ // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); int q_offs_0 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; get<2,1>(blk_coord_q) = 0; } } Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); auto [tQgQ_qdl, tQsQ] = tma_partition( params.tma_load_q, _0{}, make_layout(_1{}), group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) ); Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); // compute gK, sK Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); int kv_offs_0 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length = get<1>(params_problem_shape).cumulative_length; if (cumulative_length != nullptr) { kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; get<2,1>(blk_coord_kv) = 0; } } Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), 
Step{}); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); auto [tKgK_kdl, tKsK] = tma_partition( params.tma_load_k, _0{}, make_layout(_1{}), group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) ); Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); // compute gV, sV ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); auto [tVgV_dkl, tVsV] = tma_partition( params.tma_load_v, _0{}, make_layout(_1{}), group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) ); auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); // blk_coord in decomposed in terms of TileShape, not TileShapeQK // As such, it needs to be transformed as // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) uint32_t lane_predicate = cute::elect_one_sync(); // Q1 int q0_index = 2 * get<0>(blk_coord_q); int q1_index = 2 * get<0>(blk_coord_q) + 1; pipeline_q.producer_acquire(pipeline_q_producer_state); if (lane_predicate) { auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); } ++pipeline_q_producer_state; // K1 int k_index = 0; pipeline_kv.producer_acquire(pipeline_kv_producer_state); if (lane_predicate) { auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); } ++pipeline_kv_producer_state; // Q2 pipeline_q.producer_acquire(pipeline_q_producer_state); if 
(lane_predicate) { auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); } ++pipeline_q_producer_state; // V1 pipeline_kv.producer_acquire(pipeline_kv_producer_state); if (lane_predicate) { auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); } ++pipeline_kv_producer_state; k_index += 1; // loop: mask_tile_count -= 1; for (; mask_tile_count > 0; mask_tile_count -= 1) { // Ki pipeline_kv.producer_acquire(pipeline_kv_producer_state); if (lane_predicate) { auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); } ++pipeline_kv_producer_state; // Vi pipeline_kv.producer_acquire(pipeline_kv_producer_state); if (lane_predicate) { auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); } ++pipeline_kv_producer_state; k_index += 1; } } }; } // namespace cutlass::fmha::collective ================================================ FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. 
Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cute/arch/simd_sm100.hpp" #include "cute/tensor.hpp" #include "cute/layout.hpp" #include "../collective/fmha_common.hpp" #include "../collective/fmha_fusion.hpp" #include "../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" #include "../common/pipeline_mla.hpp" namespace cutlass::fmha::collective { using namespace cute; template< class Element_, class ElementQK_, class ElementPV_, class ComposedTileShape_, class StrideQ_, class StrideK_, class StrideV_, class Mask_, // shape here is QG K H // and referes to the two softmax warps // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) // (1, 2, 1) means they sit side by side (best for small Q / large K) class ThreadShape = Shape<_2, _1, _1>, class OrderLoadEpilogue = cute::false_type > struct Sm100MlaFwdMainloopTmaWarpspecialized { using Element = Element_; using ElementQK = ElementQK_; using ElementPV = ElementPV_; using ComposedTileShape = ComposedTileShape_; using StrideQ = StrideQ_; using StrideK = StrideK_; using StrideV = StrideV_; using Mask = Mask_; static constexpr int StageCountQ = 2; static constexpr int StageCountK = 1; static constexpr int StageCountV = 1; static constexpr int StageCountKV = StageCountK + StageCountV; // Support StageCountKV > 2 in the future. 
static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!"); static_assert(std::is_same_v>, "Only support ThreadShape = Shape<_2, _1, _1>"); using ClusterShape = Shape<_1, _1, _1>; static const int Alignment = 128 / sizeof_bits_v; static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{}); static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{}); static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope; static constexpr auto HeadDimPV = HeadDimLatent; using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{})); using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{}))); using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent)); using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, StrideQ, Alignment, Element, StrideK, Alignment, ElementQK, TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // the stride for A does not matter since we do not load from smem at all Element, StrideK, Alignment, Element, decltype(select<1,0,2>(StrideV{})), Alignment, ElementPV, TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); using SmemStorageOneStageO = 
decltype(make_layout(replace<2>(TileShapePV{}, _1{}))); // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, // we reuse shared memory for V and O to address this problem, // and a barrier has been added to coordinate access to shared memory. static constexpr bool IsOrderLoadEpilogue = std::is_same_v; static const int NumWarpsEpilogue = 1; static const int NumWarpsLoad = 1; struct TensorStorageQKVO { cute::array_aligned> smem_q; cute::array_aligned> smem_k; cute::array_aligned> smem_o; // use as O0 cute::array_aligned> smem_v; // use as V0 and O1 }; struct TensorStorageQKV { cute::array_aligned> smem_q; cute::array_aligned> smem_k; cute::array_aligned> smem_v; }; using TensorStorage = std::conditional_t; enum class TmemAllocation : uint32_t { kSizeS = 128, kSizeO = 128, kSizeP = 32, S0 = 0, S1 = S0 + kSizeS, V0 = S0, // stats storage from softmax to correction V1 = S1, P0 = S0 + kSizeP, P1 = S1 + kSizeP, O0 = S1 + kSizeS, O1 = O0 + kSizeO, kEnd = O1 + kSizeO }; // indices for V0 / V1 enum : int { kIdxOldRowMax = 0, kIdxNewRowMax = 1, kIdxFinalRowSum = 0, kIdxFinalRowMax = 1 }; // from load to mma warp, protects q in smem using PipelineQ = cutlass::PipelineTmaUmmaAsync< StageCountQ, typename CollectiveMmaQK::AtomThrShapeMNK >; // from load to mma warp, protects k/v in smem using PipelineKV = cutlass::PipelineTmaAsyncMla< StageCountKV, typename CollectiveMmaQK::AtomThrShapeMNK >; // from mma to softmax0/1 warp, protects S in tmem // (not sure yet about the reverse direction) // there is one pipe per softmax warp, and the mma warp alternates between them using PipelineS = cutlass::PipelineUmmaAsync<1>; // from softmax0/1/ to correction wg using PipelineC = cutlass::PipelineAsync<1>; // from mma to correction using PipelineO = cutlass::PipelineUmmaAsync<2>; // from corr to epilogue using PipelineE = cutlass::PipelineAsync<2>; using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< /*stages*/ 1, /*groups*/ 2>; static 
constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); using Load = Sm100MlaFwdLoadTmaWarpspecialized< Element, StrideQ, StrideK, StrideV, CollectiveMmaQK, CollectiveMmaPV, SmemLayoutQ, SmemLayoutK, SmemLayoutV, TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue >; struct Arguments { typename Load::Arguments load; // if zero, defaults to 1/sqrt(D) float scale_softmax = 0.0f; // scaling factors to dequantize QKV float scale_q = 1.0f; float scale_k = 1.0f; float scale_v = 1.0f; // scaling factor to quantize O float inv_scale_o = 1.0f; }; struct Params { typename Load::Params load; float scale_softmax; float scale_softmax_log2; float scale_output; }; template static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { return true; } template static Params to_underlying_arguments( ProblemShape const& problem_shape, Arguments const& args, void* workspace) { float scale_softmax = args.scale_softmax; if (scale_softmax == 0.0f) { scale_softmax = 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); } float log2_e = static_cast(std::log2(std::exp(1.0))); return Params{ Load::to_underlying_arguments(problem_shape, args.load, workspace), args.scale_q * args.scale_k * scale_softmax, args.scale_q * args.scale_k * log2_e * scale_softmax, args.scale_v * args.inv_scale_o }; } CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) { Load::prefetch_tma_descriptors(params.load); } template CUTLASS_DEVICE void load( BlkCoord const& blk_coord, ProblemShape const& problem_shape, Params const& params, ParamsProblemShape const& params_problem_shape, TensorStorage& storage, PipelineQ& 
pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { Load load; load.load(blk_coord, problem_shape, params.load, params_problem_shape, storage, pipeline_q, pipeline_q_producer_state, pipeline_kv, pipeline_kv_producer_state); } template CUTLASS_DEVICE auto mma( BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, TensorStorage& storage, PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { auto pipeline_q_release_state = pipeline_q_consumer_state; auto pipeline_kv_release_state = pipeline_kv_consumer_state; int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); typename CollectiveMmaQK::TiledMma mma_qk; ThrMMA thr_mma_qk = mma_qk.get_slice(0); typename CollectiveMmaPV::TiledMma mma_pv; TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); Tensor tSrK = thr_mma_qk.make_fragment_B(sK); Tensor tOrV = thr_mma_pv.make_fragment_B(sV); // tmem layout is // S0 S1`O0 O1 // sequential in memory, where S overlaps with P and V Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); Tensor tStS0 = tStS; tStS0.data() = 
tStS.data().get() + uint32_t(TmemAllocation::S0); Tensor tStS1 = tStS; tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); Tensor tOtO0 = tOtO; tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); Tensor tOtO1 = tOtO; tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging Tensor tOrP0 = tOrP; tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); Tensor tOrP1 = tOrP; tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); int k_index = 0; int v_index = 0; int q_index = 0; // wait for Q1 q_index = pipeline_q_consumer_state.index(); pipeline_q.consumer_wait(pipeline_q_consumer_state); ++pipeline_q_consumer_state; Tensor tSrQ0 = tSrQ(_,_,_,q_index); // wait for K1 k_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // gemm Q1 * K1 -> S1 pipeline_s0.producer_acquire(pipeline_s0_producer_state); gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; // release K1 if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } // wait for Q2 if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { q_index = pipeline_q_consumer_state.index(); pipeline_q.consumer_wait(pipeline_q_consumer_state); ++pipeline_q_consumer_state; } Tensor tSrQ1 = tSrQ(_,_,_,q_index); if constexpr (get<1>(ThreadShape{}) > 1) { k_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } pipeline_s1.producer_acquire(pipeline_s1_producer_state); // gemm Q2 * K1 -> S2 gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); 
pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; // release K1 pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; // wait for V1 v_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // this acquire returns the ownership of all of S0 to the mma warp // including the P0 part // acquire corr first to take it out of the critical // path since softmax takes longer pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s0.producer_acquire(pipeline_s0_producer_state); // gemm P1 * V1 -> O1 gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; // loop: mask_tile_count -= 1; for (; mask_tile_count > 0; mask_tile_count -= 1) { // wait for Ki k_index = (pipeline_kv_consumer_state.index()); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // gemm Q1 * Ki -> S1 gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } // gemm P2 * V(i-1) -> O2 if constexpr (get<1>(ThreadShape{}) > 1) { v_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s1.producer_acquire(pipeline_s1_producer_state); gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); pipeline_corr.producer_commit(pipeline_corr_producer_state); 
++pipeline_corr_producer_state; // release V(i-1) pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; if constexpr (get<1>(ThreadShape{}) > 1) { k_index = (pipeline_kv_consumer_state.index()); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } // gemm Q2 * Ki -> S2 gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; // release Ki pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; // wait for Vi v_index = (pipeline_kv_consumer_state.index()); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // gemm P1 * Vi -> O1 pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s0.producer_acquire(pipeline_s0_producer_state); gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; if constexpr (get<1>(ThreadShape{}) > 1) { pipeline_kv.consumer_release(pipeline_kv_release_state); ++pipeline_kv_release_state; } } // release Q1 pipeline_q.consumer_release(pipeline_q_release_state); ++pipeline_q_release_state; // release Q2 if constexpr (get<0>(ThreadShape{}) > 1) { pipeline_q.consumer_release(pipeline_q_release_state); ++pipeline_q_release_state; } // wait for Vi if constexpr (get<1>(ThreadShape{}) > 1) { v_index = pipeline_kv_consumer_state.index(); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; } // gemm P2 * Vi -> O2 pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s1.producer_acquire(pipeline_s1_producer_state); gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; // release Vi pipeline_kv.consumer_release(pipeline_kv_release_state); 
++pipeline_kv_release_state; pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... } template CUTLASS_DEVICE auto softmax_step( bool need_apply_mask, float& row_max, float& row_sum, Stage stage, bool final_call, BlkCoord const& blk_coord, CoordTensor const& cS, Params const& params, ProblemShape const& problem_shape, PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1)); Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); // Each thread owns a single row using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); // wait on tensor core pipe pipeline_s.consumer_wait(pipeline_s_consumer_state); // read all of S from tmem into reg mem Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); if constexpr (need_mask) { if(need_apply_mask) { Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); } } ElementQK old_row_max = row_max; { // compute rowmax float row_max_0 = row_max; float row_max_1 = row_max; float row_max_2 = row_max; float row_max_3 = row_max; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); row_max_2 = ::fmax(row_max_2, 
tTMEM_LOADrS(i+2)); row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); } row_max = ::fmax(row_max_0, row_max_1); row_max = ::fmax(row_max, row_max_2); row_max = ::fmax(row_max, row_max_3); } ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); pipeline_c.producer_commit(pipeline_c_producer_state); ++pipeline_c_producer_state; // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) ElementQK scale = params.scale_softmax_log2; ElementQK row_max_scale = row_max_safe * scale; float2 scale_fp32x2 = make_float2(scale, scale); float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); constexpr int kConversionsPerStep = 2; Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); NumericArrayConverter convert; constexpr int kReleasePipeCount = 10; // must be multiple of 2 order_s.wait(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { float2 in = make_float2( tTMEM_LOADrS(i + 0), tTMEM_LOADrS(i + 1) ); float2 out; cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); tTMEM_LOADrS(i + 0) = out.x; tTMEM_LOADrS(i + 1) = out.y; tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); Array in_conv; CUTLASS_PRAGMA_UNROLL for (int j = 0; j < kConversionsPerStep; j++) { in_conv[j] = tTMEM_LOADrS(i + j); } tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { order_s.arrive(); } // this prevents register spills in fp16 if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { if (i == size(tTMEM_LOADrS) - 6) { copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); } } } // tmem_store(reg_S8) -> op_P 
CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); cutlass::arch::fence_view_async_tmem_store(); // notify tensor core warp that P is ready pipeline_s.consumer_release(pipeline_s_consumer_state); ++pipeline_s_consumer_state; pipeline_c.producer_acquire(pipeline_c_producer_state); ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); row_sum *= acc_scale; // row_sum = sum(reg_S) float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); float2 local_row_sum_1 = make_float2(0, 0); float2 local_row_sum_2 = make_float2(0, 0); float2 local_row_sum_3 = make_float2(0, 0); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { // row_sum += tTMEM_LOADrS(i); float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); cute::add(local_row_sum_1, local_row_sum_1, in); in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); cute::add(local_row_sum_2, local_row_sum_2, in); in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); cute::add(local_row_sum_3, local_row_sum_3, in); } cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; row_sum = local_row_sum; if (final_call) { // re-acquire the S part in the final step pipeline_s.consumer_wait(pipeline_s_consumer_state); Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); } } template CUTLASS_DEVICE auto softmax( Stage 
stage, BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { const int mask_trip_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); const int total_trip_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); int trip_idx = total_trip_count; ElementQK row_max = -INFINITY; ElementQK row_sum = 0; Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); auto logical_offset = make_coord( get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) ); Tensor cS = domain_offset(logical_offset, cS_base); pipeline_c.producer_acquire(pipeline_c_producer_state); constexpr bool NeedMask = !std::is_same_v; CUTLASS_PRAGMA_NO_UNROLL for (; trip_idx > 0; trip_idx -= 1) { softmax_step( trip_idx <= mask_trip_count, row_max, row_sum, stage, trip_idx == 1, blk_coord, cS, params, problem_shape, pipeline_s, pipeline_s_consumer_state, pipeline_c, pipeline_c_producer_state, order_s ); cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); } pipeline_c.producer_commit(pipeline_c_producer_state); ++pipeline_c_producer_state; pipeline_c.producer_acquire(pipeline_c_producer_state); // empty step to sync against pipe s pipeline_s.consumer_release(pipeline_s_consumer_state); ++pipeline_s_consumer_state; } template CUTLASS_DEVICE auto correction_epilogue( float scale, Stage stage, TensorO const& sO_01) { using ElementOut = typename TensorO::value_type; int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); Tensor sO = sO_01(_,_,stage); // As opposed to the softmax, we do not have enough registers here // to load all of the values (for tile kv = 128), so we loop // good values would be 
either 32 or 64 constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut); using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOsO = mma.get_slice(0).partition_C(sO); Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); if constexpr (decltype(stage == _0{})::value) { tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); } else { static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); } auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); float2 scale_f32x2 = make_float2(scale, scale); // loop: // TMEM_LOAD, FMUL2 scale, TMEM_STORE CUTLASS_PRAGMA_UNROLL for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); #ifndef ONLY_SOFTMAX CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO); j += 2) { float2 in = make_float2(tTMrO(j), tTMrO(j+1)); float2 out; cute::mul(out, scale_f32x2, in); tTMrO(j) = out.x; tTMrO(j+1) = out.y; } #endif constexpr int N = 4 / sizeof(ElementOut); 
NumericArrayConverter convert; Tensor tSMrO = make_tensor_like(tTMrO); Tensor tCs = recast(tTMrO); Tensor tCd = recast(tSMrO); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tCs); j++) { tCd(j) = convert.convert(tCs(j)); } Tensor tSMsO_i = recast(tTMEM_LOADsO_i); Tensor tSMrO_i = recast(tSMrO); copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); } cutlass::arch::fence_view_async_shared(); } CUTLASS_DEVICE auto correction_rescale( float scale, uint32_t tmem_O) { int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); // As opposed to the softmax, we do not have enough registers here // to load all of the values (for tile kv = 128), so we loop // good values would be either 32 or 64 constexpr int kCorrectionTileSize = 16; using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); float2 scale_f32x2 = make_float2(scale, scale); Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / 
kCorrectionTileSize>{})); auto copy_in = [&](int i) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); }; auto copy_out = [&](int i) { Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); }; // sequence: LLMSLMSLMSS // loop: // TMEM_LOAD, FMUL2 scale, TMEM_STORE copy_in(0); constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < count; i++) { if (i != count - 1) { copy_in(i+1); } Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO_i); j += 2) { float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); float2 out; cute::mul(out, scale_f32x2, in); tTMrO_i(j) = out.x; tTMrO_i(j+1) = out.y; } copy_out(i); } } template< class BlkCoord, class ProblemShape, class ParamsProblemShape, class TensorStorageEpi, class CollectiveEpilogue > CUTLASS_DEVICE auto correction( BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, ParamsProblemShape const& params_problem_shape, TensorStorageEpi& shared_storage_epi, PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, CollectiveEpilogue& epilogue) { int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); Tensor tStS = 
partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); // ignore first signal from softmax as no correction is required pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); // handle the last iteration differently (i.e. 
tmem_load/stsm for epi) mask_tile_count -= 1; CUTLASS_PRAGMA_NO_UNROLL for (; mask_tile_count > 0; mask_tile_count -= 1) { pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); // read row_wise new global max copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); // e^(scale * (old_max - new_max) float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); correction_rescale(scale, uint32_t(TmemAllocation::O0)); pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; cutlass::arch::fence_view_async_tmem_store(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); correction_rescale(scale, uint32_t(TmemAllocation::O1)); pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; cutlass::arch::fence_view_async_tmem_store(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; } pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; // do the final correction to O1 // better to somehow special-case it in the loop above // doesn't matter for non-persistent code, but if it were // persistent we do not want to release O too early pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); // read from V0 // read row_sum and final row_max here Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; 
pipeline_o.consumer_wait(pipeline_o_consumer_state); pipeline_epi.producer_acquire(pipeline_epi_producer_state); // store to epi smem // loop: // TMEM_LOAD // FMUL2 scale = 1 / global_sum * out_quant_scale // F2FP // store to smem Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); if (epilogue.params.ptr_LSE != nullptr) { int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); int row_offset = 0; if constexpr (is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); // load from V1 copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; pipeline_o.consumer_wait(pipeline_o_consumer_state); pipeline_epi.producer_acquire(pipeline_epi_producer_state); correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); if (epilogue.params.ptr_LSE != nullptr) { int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); int row_offset = 0; if constexpr 
(is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; } template< class BlkCoord, class ProblemShape, class ParamsProblemShape, class TensorStorageEpi, class CollectiveEpilogue > CUTLASS_DEVICE auto correction_empty( BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, ParamsProblemShape const& params_problem_shape, TensorStorageEpi& shared_storage_epi, PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, CollectiveEpilogue& epilogue) { pipeline_epi.producer_acquire(pipeline_epi_producer_state); Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); float lse = -INFINITY; int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); #if 1 using ElementOut = typename CollectiveEpilogue::ElementOut; auto tiled_copy = make_cotiled_copy( Copy_Atom, ElementOut>{}, make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), sO.layout()); auto thr_copy = tiled_copy.get_slice(thread_idx); auto tOgO = thr_copy.partition_D(sO); auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); clear(tOrO); copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); #endif if (epilogue.params.ptr_LSE != nullptr) { int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); int row_offset = 0; if constexpr (is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = 
lse; } } pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); cutlass::arch::fence_view_async_shared(); pipeline_epi.producer_acquire(pipeline_epi_producer_state); if (epilogue.params.ptr_LSE != nullptr) { int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); int row_offset = 0; if constexpr (is_variable_length_v>) { row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; } if (row_idx < get<0>(problem_shape)) { gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } cutlass::arch::fence_view_async_shared(); pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; } }; } // namespace cutlass::fmha::collective ================================================ FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. 
*
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"

#include "../collective/fmha_common.hpp"
#include "../collective/fmha_fusion.hpp"

namespace cutlass::fmha::collective {

using namespace cute;

// TMA producer for the MLA forward kernel. It issues loads of Q, K and V from
// global memory into shared memory through two pipelines:
//   - pipeline_q carries two Q tiles per block;
//   - pipeline_kv alternates K and V tiles.
// The QK head dimension is the sum of the two parts of problem mode 2
// (get<2,0> + get<2,1>). V uses only get<2,0>.
template<
  class Element,
  class StrideQ,
  class StrideK,
  class StrideV,
  class CollectiveMmaQK,
  class CollectiveMmaPV,
  class SmemLayoutQ,
  class SmemLayoutK,
  class SmemLayoutV,
  class TensorStorage,
  class PipelineQ,
  class PipelineKV,
  class Mask,
  class TileShape,
  class OrderLoadEpilogue = cute::false_type
>
struct Sm100MlaFwdLoadTmaWarpspecialized {

  using TileShapeQK = typename CollectiveMmaQK::TileShape;
  using TileShapePV = typename CollectiveMmaPV::TileShape;

  // Bytes per K / V stage. TransactionBytesLoadV is passed to
  // producer_acquire_bytes for the V loads below.
  static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v);
  static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v);

  static const int NumWarpsEpilogue = 1;
  static const int NumWarpsLoad = 1;

  struct Arguments {
    const Element* ptr_Q;
    StrideQ dQ;
    const Element* ptr_K;
    StrideK dK;
    const Element* ptr_V;
    StrideV dV;
  };

  using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
  using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
  using TMA_V = typename CollectiveMmaPV::Params::TMA_B;

  struct Params {
    TMA_Q tma_load_q;
    TMA_K tma_load_k;
    TMA_V tma_load_v;
  };

  // Build the three TMA descriptors on the host.
  // Q and K come from the QK mainloop's A/B operands. V comes from the PV
  // mainloop's B operand; that call also needs an A operand, so K is passed
  // only as a placeholder.
  // For variable-length problems the descriptors span the packed total lengths.
  // Sequence lengths are clamped to at least 1.
  template
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape,
      Arguments const& args,
      void* workspace) {

    auto ptr_Q = args.ptr_Q;
    auto ptr_K = args.ptr_K;
    auto ptr_V = args.ptr_V;
    auto dQ = args.dQ;
    auto dK = args.dK;
    auto dV = args.dV;

    using IntProblemShape = cute::tuple, int>>;

    IntProblemShape problem_shape_qk;
    if constexpr (is_variable_length_v>) {
      auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
      auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
      // NOTE(review): if either cumulative_length pointer is null, none of the
      // modes of problem_shape_qk are assigned here. The code below then reads
      // whatever cute::tuple's default constructor left in them — confirm that
      // callers always pass both pointers in the variable-length case.
      if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
        get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
        get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
        get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
        get<3>(problem_shape_qk) = get<3>(problem_shape);
      }
    } else {
      problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;
    }

    // Keep the TMA extents non-zero even for empty sequences.
    get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
    get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));

    // PV problem: (Q, D_v, K, batch-ish) — mode 1 replaced by get<2,0> only.
    auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));

    auto params_qk = CollectiveMmaQK::to_underlying_arguments(
        problem_shape_qk,
        typename CollectiveMmaQK::Arguments {
          ptr_Q, dQ,
          ptr_K, dK,
        }, /*workspace=*/ nullptr);

    auto params_pv = CollectiveMmaPV::to_underlying_arguments(
        problem_shape_pv,
        typename CollectiveMmaPV::Arguments {
          ptr_K, dK,  // never used, dummy
          ptr_V, select<1,0,2>(dV),
        }, /*workspace=*/ nullptr);

    return Params{
        params_qk.tma_load_a,
        params_qk.tma_load_b,
        params_pv.tma_load_b
    };
  }

  // Prefetch the Q, K and V TMA descriptors.
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& params) {
    cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
  }

  // Producer loop for one block.
  // Issue order: Q0, K0, Q1, V0, then (Ki, Vi) for each remaining KV tile.
  // The number of KV tiles comes from Mask::get_trip_count. K/V smem slots are
  // indexed by pipeline_kv_producer_state.index() / 2 because K and V alternate
  // in a single pipeline.
  template
  CUTLASS_DEVICE void
  load(
      BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
      Params const& params, ParamsProblemShape const& params_problem_shape,
      TensorStorage& storage,
      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {

    BlkCoord blk_coord_q = blk_coord_in;
    BlkCoord blk_coord_kv = blk_coord_in;

    auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
    auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));

    int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);

    using X = Underscore;

    // NOTE(review): the original note here said "only executed by one thread, no
    // need to elect_one". However, elect_one_sync() is used below to predicate
    // every TMA issue. Confirm how many threads enter load().

    // Q1, K1, Q2, V1, K2, V2, K3, V3, ...
    // two pipes: Q and KV
    // from Memory (prod) to TensorCore (cons)

    // compute gQ, sQ
    // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
    ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
    Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));

    int q_offs_0 = 0;

    // Variable length: shift the Q view to this sequence's start row and index
    // the (now packed) batch mode at 0.
    if constexpr (is_variable_length_v>) {
      auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
      if (cumulative_length_q != nullptr) {
        q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
        get<2,1>(blk_coord_q) = 0;
      }
    }

    Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);

    Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
    Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
    Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
    auto [tQgQ_qdl, tQsQ] = tma_partition(
      params.tma_load_q, _0{}, make_layout(_1{}),
      group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
    );
    Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));

    // compute gK, sK
    Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));

    int kv_offs_0 = 0;

    // Variable length: same shift for K/V using the KV cumulative lengths.
    if constexpr (is_variable_length_v>) {
      auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
      if (cumulative_length != nullptr) {
        kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
        get<2,1>(blk_coord_kv) = 0;
      }
    }

    Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);

    Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{});
    Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
    auto [tKgK_kdl, tKsK] = tma_partition(
      params.tma_load_k, _0{}, make_layout(_1{}),
      group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
    );
    Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));

    // compute gV, sV
    ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
    Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));

    Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);

    Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{});
    Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
    auto [tVgV_dkl, tVsV] = tma_partition(
      params.tma_load_v, _0{}, make_layout(_1{}),
      group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
    );
    auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));

    // blk_coord is decomposed in terms of TileShape, not TileShapeQK
    // As such, it needs to be transformed as
    // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
    //          b -> 2*a (Ki i even) 2*a+1 (Ki i odd)

    uint32_t lane_predicate = cute::elect_one_sync();

    // Q1
    int q0_index = 2 * get<0>(blk_coord_q);
    int q1_index = 2 * get<0>(blk_coord_q) + 1;
    pipeline_q.producer_acquire(pipeline_q_producer_state);
    if (lane_predicate) {
      auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
    }
    ++pipeline_q_producer_state;

    // K1
    int k_index = 0;
    pipeline_kv.producer_acquire(pipeline_kv_producer_state);
    if (lane_predicate) {
      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
      copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
    }
    ++pipeline_kv_producer_state;

    // Q2
    pipeline_q.producer_acquire(pipeline_q_producer_state);
    if (lane_predicate) {
      auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
    }
    ++pipeline_q_producer_state;

    // Optional ordering with the epilogue: when enabled (template switch), the
    // load and epilogue warps sync on the reserved EpilogueBarrier here.
    if constexpr (cute::is_same_v) {
      cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
                                        cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
    }

    // V1
    pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
    if (lane_predicate) {
      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
      copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
    }
    ++pipeline_kv_producer_state;
    k_index += 1;

    // loop:
    mask_tile_count -= 1;
    for (; mask_tile_count > 0; mask_tile_count -= 1) {

      // Ki
      pipeline_kv.producer_acquire(pipeline_kv_producer_state);
      if (lane_predicate) {
        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
        copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));

        // prefetch vi
        cute::prefetch(params.tma_load_v, tVgV(_, k_index));
      }
      ++pipeline_kv_producer_state;

      // Vi
      pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
      if (lane_predicate) {
        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
        copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));

        // prefetch ki+1 (only if another K tile remains)
        if(mask_tile_count > 1) {
          cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));
        }
      }
      ++pipeline_kv_producer_state;

      k_index += 1;
    }
  }
};

}  // namespace cutlass::fmha::collective
================================================
FILE: csrc/sm100/prefill/dense/common/gather_tensor.hpp
================================================
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1.
Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"

namespace example {

using namespace cute;

// Empty type used to disable gather/scatter for a GEMM argument.
// Its variadic constructor accepts and ignores any arguments.
struct NoGather
{
  template
  NoGather(Ts...) {};
};

/// Function object that applies an index to its argument:
/// operator()(i) returns indices_[i].
template
struct IndexedGather
{
  CUTE_HOST_DEVICE constexpr
  IndexedGather(Index const *indices = {}): indices_(indices) {}

  template
  CUTE_HOST_DEVICE constexpr
  Index
  operator()(I i) const { return indices_[i]; }

  CUTE_HOST_DEVICE friend
  void
  print(IndexedGather const &s) {
    cute::print("Indexed");
  }

  Index const *indices_;
};

/// Function object that applies a stride to its argument:
/// operator()(i) returns i * stride_.
/// Example: a StridedGather with stride 2 gathers every other row/column.
template
struct StridedGather
{
  CUTE_HOST_DEVICE constexpr
  StridedGather(Stride stride = {}): stride_(stride) {}

  template
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(I i) const { return i * stride_; }

  CUTE_HOST_DEVICE friend
  void
  print(StridedGather const &s) {
    cute::print("Strided{");
    print(s.stride_);
    cute::print("}");
  }

  Stride stride_;
};

/// Custom stride object that applies a function followed by a stride:
/// multiplying an index i by it (from either side) yields func_(i) * stride_.
template
struct CustomStride
{
  CUTE_HOST_DEVICE constexpr
  CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {}

  template
  CUTE_HOST_DEVICE constexpr friend
  auto
  operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; }

  template
  CUTE_HOST_DEVICE constexpr friend
  auto
  operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; }

  CUTE_HOST_DEVICE friend
  void
  print(CustomStride const & s) {
    cute::print("Custom{");
    print(s.func_);
    cute::print(",");
    print(s.stride_);
    cute::print("}");
  }

  // Dividing a CustomStride divides only the inner stride; the function is kept.
  template
  CUTE_HOST_DEVICE constexpr friend
  auto
  safe_div(CustomStride const &s, Div const &div)
  {
    return CustomStride(s.func_, safe_div(s.stride_, div));
  }

  // Circumvent the requirement on make_layout that shape and stride are integral
  template
  CUTE_HOST_DEVICE constexpr friend
  auto
  make_layout(Shape const &shape, CustomStride const &stride)
  {
    return Layout(shape, stride);
  }

  Func func_;
  Stride stride_;
};

// Build a layout with a dummy all-ones shape (repeat_like(stride, _1{})).
// In that layout, the first mode whose stride is not the constant 1 is replaced
// by a CustomStride that wraps `func` and that mode's original stride.
template
CUTLASS_HOST_DEVICE
auto
make_custom_stride_layout(Stride const &stride, Func&& func)
{
  // Use a dummy shape and replace the first non-unit stride with a custom gather stride
  auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; });
  constexpr int I = decltype(idx)::value;
  return make_layout(repeat_like(stride, _1{}),
                     replace(stride, CustomStride{static_cast(func), get(stride)}));
}

/// Helper function to optionally create a gather tensor.
/// When the function type is not NoGather, the result uses a ComposedLayout:
///   - outer: the custom-stride gather layout;
///   - offset: zero;
///   - inner: the identity layout of `shape`.
/// Coordinates are therefore routed through `func` before being scaled by the
/// stride. For NoGather it returns a plain strided tensor over (shape, stride).
template
CUTLASS_HOST_DEVICE
auto
make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func)
{
  if constexpr (not cutlass::platform::is_same, NoGather>::value) {
    Layout matrix_layout = make_identity_layout(shape);
    auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
    Layout gather_layout = make_custom_stride_layout(stride, static_cast(func));
    return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
  } else {
    return make_tensor(iter, shape, stride);
  }
}

} // namespace example

namespace cute
{

// upcast for a (shape, stride) pair whose strides may be scaled bases:
//   - tuples are processed mode by mode;
//   - a scaled-basis stride on mode I has both its shape and stride
//     ceil-divided by the upcast factor;
//   - other scaled-basis modes are returned unchanged;
//   - anything else is forwarded to the regular upcast.
template
CUTE_HOST_DEVICE constexpr
auto
upcast(Shape const& shape, Stride const& stride)
{
  if constexpr (is_tuple::value) {
    return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); });
  } else if constexpr (is_scaled_basis::value) {
    if constexpr (Stride::mode() == I) {
      return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{}));
    } else {
      return make_layout(shape, stride);
    }
  } else {
    return upcast(shape, stride);
  }

  CUTE_GCC_UNREACHABLE;
}

// upcast for the ComposedLayout built by make_gather_tensor.
// The outer layout is upcast normally. The offset and the inner layout are
// upcast only along the mode whose outer stride is the constant 1.
template
CUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout,Offset,Layout> const& layout)
{
  // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
  auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; });
  constexpr int I = decltype(idx)::value;

  // Upcast the outer layout (works as expected)
  auto outer = upcast(layout.layout_a());

  // Upcast the accumulated offset along stride-1 mode
  auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset()))));

  // Upcast the inner layout's shape along stride-1 mode
  auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride());

  return composition(outer, offset, inner);
}

} // namespace cute
================================================
FILE: csrc/sm100/prefill/dense/common/helper.h
================================================
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include "cuda_runtime.h" #include /** * Panic wrapper for unwinding CUTLASS errors */ #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ if (error != cutlass::Status::kSuccess) { \ std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ << std::endl; \ exit(EXIT_FAILURE); \ } \ } /** * Panic wrapper for unwinding CUDA runtime errors */ #define CUDA_CHECK(status) \ { \ cudaError_t error = status; \ if (error != cudaSuccess) { \ std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ << " at line: " << __LINE__ << std::endl; \ exit(EXIT_FAILURE); \ } \ } #define FLASH_MLA_ASSERT(cond) \ do { \ if (!(cond)) { \ std::cerr << "FLASH_MLA_ASSERT: " << #cond << " failed at " << __FILE__ << ":" << __LINE__ << std::endl; \ std::abort(); \ } \ } while (0) ================================================ FILE: csrc/sm100/prefill/dense/common/mask.cuh ================================================ #pragma once enum class MaskMode { kNone = 0U, // No mask kCausal = 1U, // Causal mask kCustom = 2U, // Custom mask }; ================================================ FILE: csrc/sm100/prefill/dense/common/pipeline_mla.hpp ================================================ /*************************************************************************************************** * Copyright (c) 
2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Support the producer to acquire specific bytes of data. 
*/

#pragma once

#include "cutlass/pipeline/sm100_pipeline.hpp"

namespace cutlass {

using namespace cute;

// TMA-producer / UMMA-consumer pipeline used by the MLA kernels.
// Standard operations (producer_acquire, producer_get_barrier, consumer_wait) are
// forwarded to the wrapped PipelineTmaUmmaAsync (Impl).
// The addition is producer_acquire_bytes(), which lets the producer arm a stage's
// full barrier with an explicit transaction byte count.
// Barrier and multicast-mask initialization is done by this class rather than by
// Impl: Impl is constructed with both of its init flags set to false_type.
template <
  int Stages_,
  class ClusterShape = Shape,
  class AtomThrShape_MNK_ = Shape<_1,_1,_1>
>
class PipelineTmaAsyncMla {
public:
  static constexpr uint32_t Stages = Stages_;
  using AtomThrShape_MNK = AtomThrShape_MNK_;
private:
  using Impl = PipelineTmaUmmaAsync;
public:
  using FullBarrier = typename Impl::FullBarrier;
  using EmptyBarrier = typename Impl::EmptyBarrier;
  using ProducerBarrierType = typename Impl::ProducerBarrierType;
  using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
  using PipelineState = typename Impl::PipelineState;
  using SharedStorage = typename Impl::SharedStorage;
  using ThreadCategory = typename Impl::ThreadCategory;
  using Params = typename Impl::Params;

  using McastDirection = McastDirection;

  // Helper function to initialize barriers.
  // Only the warp whose index equals params.initializing_warp initializes the
  // full/empty barrier arrays.
  //   - producer arrival count: 1
  //   - consumer arrival count:
  //       (mode-0 cluster extent / mode-0 atom extent)
  //     + (mode-1 cluster extent / mode-1 atom extent) - 1
  // Every calling thread then executes fence_barrier_init().
  static
  CUTLASS_DEVICE
  void
  init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {
    int warp_idx = canonical_warp_idx_sync();
    if (warp_idx == params.initializing_warp) {
      // Barrier FULL and EMPTY init
      constexpr int producer_arv_cnt = 1;
      auto atom_thr_shape = AtomThrShape_MNK{};
      uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
                                                        (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1;

      cutlass::arch::detail::initialize_barrier_array_pair_aligned(
          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
    }
    cutlass::arch::fence_barrier_init();
  }

  // Direction-aware variant of init_barriers.
  // The consumer arrival count is the cluster/atom extent ratio along a single mode:
  //   - mode 1 for McastDirection::kRow,
  //   - mode 0 otherwise.
  static
  CUTLASS_DEVICE
  void
  init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {
    auto atom_thr_shape = AtomThrShape_MNK{};

    int warp_idx = canonical_warp_idx_sync();
    if (warp_idx == params.initializing_warp) {
      // Barrier FULL and EMPTY init
      constexpr int producer_arv_cnt = 1;
      uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow)
          ? cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas
            cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape);  // Mcast with col ctas

      cutlass::arch::detail::initialize_barrier_array_pair_aligned(
          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
    }
    cutlass::arch::fence_barrier_init();
  }

  // Computes block_id_mask_, the multicast mask later used by consumer_release().
  // Only consumer-role threads compute it; other roles keep the default of 0.
  // NOTE(review): cluster_layout is computed but never used.
  CUTLASS_DEVICE
  void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
    // Calculate consumer mask
    if (params_.role == ThreadCategory::Consumer) {
      auto cluster_layout = make_layout(cluster_shape);
      block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
  }

  // Direction-aware variant: computes block_id_mask_ for the given multicast direction.
  // Unlike the overload above, this does not check params_.role.
  // NOTE(review): both branches read identically in this copy. Confirm they are meant
  // to differ, e.g. by a direction template argument on calculate_multicast_mask.
  // cluster_layout is unused.
  CUTLASS_DEVICE
  void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
    // Calculate consumer mask
    dim3 block_id_in_cluster = cute::block_id_in_cluster();
    auto cluster_layout = make_layout(cluster_shape);
    if (mcast_direction == McastDirection::kRow) {
      block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
    else {
      block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
  }


public:
  // Builds the wrapped Impl without letting it initialize barriers or masks.
  // The two trailing tag arguments then decide whether this constructor does that work.
  // NOTE(review): the static_asserts presumably restrict InitBarriers / InitMasks to
  // cute::true_type or cute::false_type. Confirm against the template parameter list.
  template
  CUTLASS_DEVICE
  PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})
      : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
      , params_(params)
      , empty_barrier_ptr_(&storage.empty_barrier_[0])
      , full_barrier_ptr_(&storage.full_barrier_[0]) {
    static_assert(cute::is_same_v || cute::is_same_v);
    if constexpr (cute::is_same_v) {
      init_barriers(storage, params_, cluster_shape);
    }

    static_assert(cute::is_same_v || cute::is_same_v);
    if constexpr (cute::is_same_v) {
      init_masks(cluster_shape);
    }
  }

  // Same as above, but barrier arrival counts and masks are computed for a single
  // multicast direction.
  template
  CUTLASS_DEVICE
  PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})
      : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
      , params_(params)
      , empty_barrier_ptr_(&storage.empty_barrier_[0])
      , full_barrier_ptr_(&storage.full_barrier_[0]) {
    static_assert(cute::is_same_v || cute::is_same_v);
    if constexpr (cute::is_same_v) {
      init_barriers(storage, params_, cluster_shape, mcast_direction);
    }

    static_assert(cute::is_same_v || cute::is_same_v);
    if constexpr (cute::is_same_v) {
      init_masks(cluster_shape, mcast_direction);
    }
  }

  // Forwards to Impl::producer_acquire.
  CUTLASS_DEVICE
  void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
    impl_.producer_acquire(state, barrier_token);
  }

  // Producer acquire with an explicit transaction size.
  // Waits on the stage's empty barrier for `phase`, unless barrier_token already
  // reports WaitDone. The leader thread then arrives on the stage's full barrier and
  // sets its expected transaction count to `bytes`.
  // In debug builds (NDEBUG not defined) the function traps with brkpt in two cases:
  // when it is called from a Consumer or NonParticipant thread, or when a leader's
  // threadIdx.x is not a multiple of 32.
  CUTLASS_DEVICE
  void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {
    detail::pipeline_check_is_producer(params_.role);
    if (barrier_token != BarrierStatus::WaitDone) {
      empty_barrier_ptr_[stage].wait(phase);
    }

    if (params_.is_leader) {
      full_barrier_ptr_[stage].arrive_and_expect_tx(bytes);
    }
#ifndef NDEBUG
    if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) {
      asm volatile ("brkpt;\n" ::);
    }

    // Most likely you have elected more than one leader
    if (params_.is_leader && (threadIdx.x % 32 != 0)) {
      asm volatile ("brkpt;\n" ::);
    }
#endif
  }

  // Convenience overload that takes the stage index and phase from a PipelineState.
  CUTLASS_DEVICE
  void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
    producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token);
  }

  // Forwards to Impl::producer_get_barrier.
  CUTLASS_DEVICE
  ProducerBarrierType* producer_get_barrier(PipelineState state) {
    return impl_.producer_get_barrier(state);
  }

  // Forwards to Impl::consumer_wait.
  CUTLASS_DEVICE
  void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
    impl_.consumer_wait(state, barrier_token);
  }

  // Releases the stage referenced by `state` (never skipped).
  CUTLASS_DEVICE
  void consumer_release(PipelineState state) {
    consumer_release(state.index(), false);
  }

private:
  Impl impl_;
  Params params_;
  EmptyBarrier *empty_barrier_ptr_;
  FullBarrier *full_barrier_ptr_;
  // Multicast mask used by consumer_release(); set by init_masks().
  uint16_t block_id_mask_ = 0;
  static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;

  // Consumer signalling Producer of completion.
  // Ensures all blocks in the Same Row and Column get notified.
  // Nothing is signalled when `skip` is nonzero. Otherwise:
  //   - with a 2-SM MMA atom, the arrive is multicast through umma_arrive_multicast_2x1SM;
  //   - when the cluster size is statically 1, a local umma_arrive is used;
  //   - otherwise, umma_arrive_multicast is used with block_id_mask_.
  CUTLASS_DEVICE
  void consumer_release(uint32_t stage, uint32_t skip) {
    detail::pipeline_check_is_consumer(params_.role);
    uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]);
    if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
      if (!skip) {
        cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
      }
    }
    else {
      if (!skip) {
        if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) {
          cutlass::arch::umma_arrive(smem_ptr);
        }
        else {
          cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
        }
      }
    }
  }
};

} // namespace cutlass

================================================
FILE: csrc/sm100/prefill/dense/common/pow_2.hpp
================================================
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

// NOTE(review): the header names of the three #include lines below are missing in
// this copy of the file. Confirm them against the upstream source.
#include
#include
#include

namespace cutlass::fmha {

// Integer divisor that is assumed to be a power of two (TODO confirm at call sites).
// The operators below implement:
//   a / b  as a right shift by log2_n,
//   a % b  as a bit mask with n - 1,
//   a < b  as a plain compare against n.
// log2_n is the index of the lowest set bit of n, which equals log2(n) only when n
// is a power of two.
struct Pow2 {
  int n;
  int log2_n;

  // log2_n is assigned only when compiling device code (__CUDA_ARCH__ defined).
  // In a host compilation pass it is left uninitialized.
  explicit CUTE_DEVICE Pow2(int n) : n(n) {
#ifdef __CUDA_ARCH__
    log2_n = __ffs(n) - 1;
#endif
  }

  // Multiplication by a runtime value yields the plain product n * b.
  template
  CUTE_HOST_DEVICE T operator *(T const& b) const {
    return n * b;
  }

  // Multiplication by a static integer.
  // NOTE(review): `N & (N - 1) == 0` parses as `N & ((N - 1) == 0)` because == binds
  // tighter than &. The Pow2 branch is therefore taken only for N == 1, not for every
  // power of two. In that case both return statements are instantiated (there is no
  // else) with different types (Pow2 vs. int), so return-type deduction would be
  // ill-formed.
  // The intended test is probably `(N & (N - 1)) == 0`. Fixing it changes the result
  // type for power-of-two N, so audit callers before changing it.
  template
  CUTE_HOST_DEVICE auto operator *(Int const&) const {
    if constexpr (N & (N - 1) == 0) {
      return Pow2{n * N};
    }
    return n * N;
  }
};

// a / b computed as a right shift by b.log2_n.
// NOTE(review): for negative signed a, a right shift rounds toward negative infinity
// rather than toward zero as built-in / does. Callers presumably pass non-negative values.
template
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
  return a >> b.log2_n;
}

// a % b computed as a & (b.n - 1). This matches built-in % only for non-negative a
// when b.n is a power of two.
template
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
  return a & (b.n - 1);
}

// Compares a against the divisor value b.n.
template
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
  return a < b.n;
}

// Prints the value as "2^<log2_n>".
CUTE_HOST_DEVICE void print(Pow2 const& a) {
  printf("2^%d", a.log2_n);
}

} // end namespace cutlass::fmha

namespace cute {

// NOTE(review): presumably specializes cute::is_integral for cutlass::fmha::Pow2 so that
// CuTe treats it as an integer-like value. Confirm the specialization argument.
template <>
struct is_integral : true_type {};

} // end namespace cute

================================================
FILE: csrc/sm100/prefill/dense/common/utils.hpp
================================================
#pragma once

// NOTE(review): the header name of the first #include below is missing in this copy.
#include
#include "cutlass/numeric_types.h"
#include "helper.h"

// Maps an element type to the corresponding CUTLASS numeric type.
// The primary template is the identity mapping.
template struct cutlass_dtype {
  using type = T;
};

// Specialization producing cutlass::half_t (presumably for the CUDA half type — confirm).
template <> struct cutlass_dtype {
  using type = cutlass::half_t;
};

// Specialization producing cutlass::bfloat16_t (presumably for the CUDA bfloat16 type — confirm).
template <> struct cutlass_dtype {
  using type = cutlass::bfloat16_t;
};

// CUDA FP8 E4M3 maps to cutlass::float_e4m3_t.
template <> struct cutlass_dtype<__nv_fp8_e4m3> {
  using type = cutlass::float_e4m3_t;
};

// CUDA FP8 E5M2 maps to cutlass::float_e5m2_t.
template <> struct cutlass_dtype<__nv_fp8_e5m2> {
  using type = cutlass::float_e5m2_t;
};

// Convenience alias for the mapped type.
template
using cutlass_dtype_t = typename cutlass_dtype::type;

================================================
FILE: csrc/sm100/prefill/dense/device/fmha.hpp
================================================
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief An universal device layer for cutlass 3.x-style kernels. */ #pragma once // common #include "cutlass/cutlass.h" #include "cutlass/device_kernel.h" #if !defined(__CUDACC_RTC__) #include "cutlass/cluster_launch.hpp" #include "cutlass/trace.h" #endif // !defined(__CUDACC_RTC__) //////////////////////////////////////////////////////////////////////////////// namespace cutlass::fmha::device { //////////////////////////////////////////////////////////////////////////////// ////////////////////////////// CUTLASS 3.x API ///////////////////////////////// //////////////////////////////////////////////////////////////////////////////// template class FMHA { public: using Kernel = Kernel_; static int const kThreadCount = Kernel::MaxThreadsPerBlock; /// Argument structure: User API using Arguments = typename Kernel::Arguments; /// Argument structure: Kernel API using Params = typename Kernel::Params; private: /// Kernel API parameters object Params params_; bool is_initialized(bool set = false) { static bool initialized = false; if (set) initialized = true; return initialized; } public: /// Access the Params structure Params const& params() const { return params_; } /// Determines whether the GEMM can execute the given problem. 
static Status can_implement(Arguments const& args) { if (Kernel::can_implement(args)) { return Status::kSuccess; } else { return Status::kInvalid; } } /// Gets the workspace size static size_t get_workspace_size(Arguments const& args) { size_t workspace_bytes = 0; workspace_bytes += Kernel::get_workspace_size(args); return workspace_bytes; } /// Computes the grid shape static dim3 get_grid_shape(Params const& params) { return Kernel::get_grid_shape(params); } /// Computes the maximum number of active blocks per multiprocessor static int maximum_active_blocks(int /* smem_capacity */ = -1) { CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); int max_active_blocks = -1; int smem_size = Kernel::SharedStorageSize; // first, account for dynamic smem capacity if needed cudaError_t result; if (smem_size >= (48 << 10)) { CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); result = cudaFuncSetAttribute( device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); if (cudaSuccess != result) { result = cudaGetLastError(); // to clear the error bit CUTLASS_TRACE_HOST( " cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); return -1; } } // query occupancy after setting smem size result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, device_kernel, Kernel::MaxThreadsPerBlock, smem_size); if (cudaSuccess != result) { result = cudaGetLastError(); // to clear the error bit CUTLASS_TRACE_HOST( " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " << cudaGetErrorString(result)); return -1; } CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); return max_active_blocks; } /// Initializes GEMM state from arguments. Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " << workspace << ", stream: " << (stream ? 
"non-null" : "null")); // Initialize the workspace Status status = Kernel::initialize_workspace(args, workspace, stream); if (status != Status::kSuccess) { return status; } // Initialize the Params structure params_ = Kernel::to_underlying_arguments(args, workspace); if (is_initialized()) return Status::kSuccess; // account for dynamic smem capacity if needed int smem_size = Kernel::SharedStorageSize; if (smem_size >= (48 << 10)) { CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); cudaError_t result = cudaFuncSetAttribute( device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); if (cudaSuccess != result) { result = cudaGetLastError(); // to clear the error bit CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); return Status::kErrorInternal; } } is_initialized(true); return Status::kSuccess; } /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. Status update(Arguments const& args, void* workspace = nullptr) { CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); size_t workspace_bytes = get_workspace_size(args); if (workspace_bytes > 0 && nullptr == workspace) { return Status::kErrorWorkspaceNull; } params_ = Kernel::to_underlying_arguments(args, workspace); return Status::kSuccess; } /// Primary run() entry point API that is static allowing users to create and manage their own params. 
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments() static Status run(Params& params, cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("FMHA::run()"); dim3 const block = Kernel::get_block_shape(); dim3 const grid = get_grid_shape(params); // No need to launch the kernel if(grid.x == 0 || grid.y == 0 || grid.z == 0) { return Status::kSuccess; } // configure smem size and carveout int smem_size = Kernel::SharedStorageSize; Status launch_result; // Use extended launch API only for mainloops that use it if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), cute::size<1>(typename Kernel::ClusterShape{}), cute::size<2>(typename Kernel::ClusterShape{})); void const* kernel = (void const*) device_kernel; void* kernel_params[] = {¶ms}; launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); } else { launch_result = Status::kSuccess; device_kernel<<>>(params); } cudaError_t result = cudaGetLastError(); if (cudaSuccess == result && Status::kSuccess == launch_result) { return Status::kSuccess; } else { CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); return Status::kErrorInternal; } } // // Non-static launch overloads that first create and set the internal params struct of this kernel handle. // /// Launches the kernel after first constructing Params internal state from supplied arguments. Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { Status status = initialize(args, workspace, stream); if (Status::kSuccess == status) { status = run(params_, stream); } return status; } /// Launches the kernel after first constructing Params internal state from supplied arguments. 
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { return run(args, workspace, stream); } /// Overload that allows a user to re-launch the same kernel without updating internal params struct. Status run(cudaStream_t stream = nullptr) { return run(params_, stream); } /// Overload that allows a user to re-launch the same kernel without updating internal params struct. Status operator()(cudaStream_t stream = nullptr) { return run(params_, stream); } }; //////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::device //////////////////////////////////////////////////////////////////////////////// ================================================ FILE: csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once // common #include "cutlass/cutlass.h" #include "cutlass/kernel_hardware_info.hpp" #include "cute/tensor.hpp" #include "../device/fmha.hpp" #include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" #include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp" #include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" #include "../kernel/fmha_kernel_bwd_convert.hpp" //////////////////////////////////////////////////////////////////////////////// namespace cutlass::fmha::device { //////////////////////////////////////////////////////////////////////////////// ////////////////////////////// CUTLASS 3.x API ///////////////////////////////// //////////////////////////////////////////////////////////////////////////////// template< class ProblemShape, class Element, class ElementAccumulator, class TileShape, bool IsMla, class Mask > class Sm100FmhaBwd { public: /// Argument structure: User API struct Arguments { // Q K D D_VO HB ProblemShape problem_shape; const Element* ptr_Q; cute::tuple> stride_Q; const Element* ptr_K; cute::tuple> stride_K; const Element* ptr_V; cute::tuple> 
stride_V; const Element* ptr_O; cute::tuple> stride_O; const ElementAccumulator* ptr_LSE; cute::tuple> stride_LSE; const Element* ptr_dO; cute::tuple> stride_dO; Element* ptr_dQ; cute::tuple> stride_dQ; Element* ptr_dK; cute::tuple> stride_dK; Element* ptr_dV; cute::tuple> stride_dV; ElementAccumulator softmax_scale; cutlass::KernelHardwareInfo hw_info; }; using OperationSumOdO = cutlass::fmha::device::FMHA< cutlass::fmha::kernel::FmhaKernelBwdSumOdO >; using OperationConvert = cutlass::fmha::device::FMHA< cutlass::fmha::kernel::FmhaKernelBwdConvert >; using OperationMha= cutlass::fmha::device::FMHA< cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< ProblemShape, Element, ElementAccumulator, TileShape, Mask > >; using OperationMla = cutlass::fmha::device::FMHA< cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< ProblemShape, Element, ElementAccumulator, TileShape, Mask > >; using Operation = std::conditional_t; using Kernel = typename Operation::Kernel; struct Params { OperationSumOdO op_sum_OdO; Operation op; OperationConvert op_convert; ElementAccumulator* dQ_acc; size_t dQ_acc_size; }; private: Params params_; static typename OperationSumOdO::Arguments to_sum_OdO_arguments( Arguments const& args, ElementAccumulator* sum_odo = nullptr, ElementAccumulator* scaled_lse = nullptr) { using namespace cute; auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); auto log2_e = log2f(expf(1.0f)); return typename OperationSumOdO::Arguments { args.problem_shape, args.ptr_O, args.stride_O, args.ptr_dO, args.stride_dO, sum_odo, stride_sum_OdO, args.ptr_LSE, args.stride_LSE, scaled_lse, stride_scaled_lse, -1.0f, -log2_e }; } static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, 
ElementAccumulator* src = nullptr) { using namespace cute; auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); return typename OperationConvert::Arguments { args.problem_shape, src, stride_src_dQ, nullptr, stride_src_dQ, nullptr, stride_src_dQ, args.ptr_dQ, args.stride_dQ, nullptr, args.stride_dK, nullptr, args.stride_dV, args.softmax_scale }; } static typename Operation::Arguments to_bwd_arguments( Arguments const& args, ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { return typename Operation::Arguments{ args.problem_shape, { args.ptr_Q, args.stride_Q, args.ptr_K, args.stride_K, args.ptr_V, args.stride_V, args.ptr_dO, args.stride_dO, scaled_lse, stride_scaled_lse, sum_OdO, stride_sum_OdO, dQ_acc, stride_dQ, args.softmax_scale }, { args.ptr_dK, args.stride_dK, args.ptr_dV, args.stride_dV }, args.hw_info }; } public: /// Determines whether the GEMM can execute the given problem. 
static Status can_implement(Arguments const& args) { Status status = Status::kSuccess; status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); if (status != Status::kSuccess) { return status; } status = OperationConvert::can_implement(to_convert_arguments(args)); if (status != Status::kSuccess) { return status; } status = Operation::can_implement(to_bwd_arguments(args)); if (status != Status::kSuccess) { return status; } return status; } /// Gets the workspace size static size_t get_workspace_size(Arguments const& args) { auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment size_t workspace_bytes = 0; // OdO vector workspace_bytes += sizeof(ElementAccumulator) * B*H*Q; // scaled LSE vector workspace_bytes += sizeof(ElementAccumulator) * B*H*Q; // FP32 versions of outputs that are churned (start off with Q only) workspace_bytes += sizeof(ElementAccumulator) * B*H*Q*D; return workspace_bytes; } /// Initializes state from arguments. Status initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? 
"non-null" : "null")); auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); params_.dQ_acc = dQ_acc; params_.dQ_acc_size = sizeof(ElementAccumulator) * B*H*Q*D; auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); auto args_convert = to_convert_arguments(args, dQ_acc); params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); params_.op_convert.initialize(args_convert, nullptr, stream); auto args_bwd = to_bwd_arguments( args, sum_OdO, args_sum_OdO.stride_sum_OdO, scaled_lse, args_sum_OdO.stride_scaled_lse, dQ_acc, args_convert.stride_src_dQ ); params_.op.initialize(args_bwd, nullptr, stream); return Status::kSuccess; } /// Initializes state from arguments. Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("Universal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); workspace_chr += sizeof(ElementAccumulator) * B*H*Q; ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); workspace_chr += sizeof(ElementAccumulator) * B*H*Q; ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); } /// Primary run() entry point API that is static allowing users to create and manage their own params. 
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments() static Status run(Params& params, cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); Status result = Status::kSuccess; result = params.op_sum_OdO.run(stream); if (result != Status::kSuccess) { return result; } auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); if (cuda_result != cudaSuccess) { return Status::kErrorInternal; } result = params.op.run(stream); if (result != Status::kSuccess) { return result; } result = params.op_convert.run(stream); if (result != Status::kSuccess) { return result; } return Status::kSuccess; } // // Non-static launch overloads that first create and set the internal params struct of this kernel handle. // /// Launches the kernel after first constructing Params internal state from supplied arguments. Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { Status status = initialize(args, workspace, stream); if (Status::kSuccess == status) { status = run(params_, stream); } return status; } /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
Status run(cudaStream_t stream = nullptr) { return run(params_, stream); } }; //////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::fmha::device //////////////////////////////////////////////////////////////////////////////// ================================================ FILE: csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu ================================================ #include "interface.h" #include #include #include #include "common/mask.cuh" #include "common/utils.hpp" #include "fmha_cutlass_bwd_sm100.cuh" template void call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { static constexpr bool IsVarlen = std::is_same_v; static constexpr bool IsMla = std::is_same_v; using TileShape = std::conditional_t, Shape<_128, _128, _128, _128>>; run_fmha_bwd(workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, softmax_scale, max_seqlen_q, total_seqlen_kv); } void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) { const c10::cuda::OptionalCUDAGuard device_guard(q.device()); int head_dim_qk = q.size(-1); int head_dim_vo = v.size(-1); MaskMode mask_mode = static_cast(mask_mode_code); auto scalar_type_in = q.scalar_type(); auto scalar_type_out = o.scalar_type(); if(scalar_type_in == 
at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) { using Element = cutlass::bfloat16_t; using ElementOut = cutlass::bfloat16_t; auto apply_config = [&](auto fn) { if (mask_mode == MaskMode::kCausal) { if(is_varlen) { fn(CausalForBackwardMask{}, cute::true_type{}, Element{}, ElementOut{}); } else { fn(CausalForBackwardMask{}, cute::false_type{}, Element{}, ElementOut{}); } } else { if(is_varlen) { fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{}); } else { fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{}); } } }; apply_config([&](auto mask, auto varlen, auto in, auto out) { if (head_dim_qk == 192 && head_dim_vo == 128) { call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, softmax_scale, max_seqlen_q, max_seqlen_kv); } else if (head_dim_qk == 128 && head_dim_vo == 128) { call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, softmax_scale, max_seqlen_q, max_seqlen_kv); } else { std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl; } }); } else { FLASH_MLA_ASSERT(false); } } ================================================ FILE: csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh ================================================ /*************************************************************************************************** * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include #include #include #include #include #include #include #include #include #include "common/utils.hpp" #include "collective/fmha_fusion.hpp" #include "device/fmha_device_bwd.hpp" #include #include using namespace cute; using namespace cutlass::fmha::kernel; using namespace cutlass::fmha::collective; using namespace cutlass::fmha; using namespace cutlass; template< class DType, bool kIsVarlen, bool kIsMla, class TileShape, class ActiveMask > struct BwdRunner { using Element = DType; using ElementAccumulator = float; // Q K D D_VO (H B) using ProblemShape = std::conditional_t< kIsVarlen, cute::tuple>, cute::tuple> >; using Operation = cutlass::fmha::device::Sm100FmhaBwd; using TensorStride = Stride>; using StrideQ = TensorStride; // Seq DQK (H B) using StrideK = TensorStride; // Seq DQK (H B) using StrideV = TensorStride; // Seq DVO (H B) using StrideO = TensorStride; // Seq DVO (H B) using StrideLSE = Stride<_1, Stride>; // Seq (H B) // Backwards specific using StrideDQ = TensorStride; using StrideDK = TensorStride; // Seq DQK (H B) using StrideDV = TensorStride; // Seq DVO (H B) using StrideDO = TensorStride; static void run(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { const at::cuda::CUDAGuard device_guard{(char)q.get_device()}; const int device_id = q.get_device(); cutlass::KernelHardwareInfo hw_info; hw_info.device_id =device_id; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); ProblemShape problem_shape; cute::tuple> tensor_shape; int d = q.size(-1); int d_vo = v.size(-1); int batch_size = cumulative_seqlen_q.size(0) - 1; int num_qo_heads = q.size(1); 
int total_seqlen_q = q.size(0); int total_seqlen_kv = k.size(0); //varlen: q: [Q, H, D] //fixedlen: q: [B, H, Q, D] if constexpr (kIsVarlen) { problem_shape = cute::make_tuple( VariableLength{max_seqlen_q, static_cast(cumulative_seqlen_q.data_ptr()), total_seqlen_q}, VariableLength{max_seqlen_kv, static_cast(cumulative_seqlen_kv.data_ptr()), total_seqlen_kv}, d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, d, d_vo, make_shape(num_qo_heads, 1)); } else { int q_len = total_seqlen_q / batch_size; int kv_len = total_seqlen_kv / batch_size; problem_shape = cute::make_tuple(q_len, kv_len, d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); tensor_shape = problem_shape; } auto [Q, K, D, D_VO, HB] = tensor_shape; auto [H, B] = HB; int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1); int dq_stride0 = dq.stride(0), dq_stride1 = dq.stride(1), dq_stride2 = dq.stride(2); int dk_stride0 = dk.stride(0), dk_stride1 = dk.stride(1), dk_stride2 = dk.stride(2); int dv_stride0 = dv.stride(0), dv_stride1 = dv.stride(1), dv_stride2 = dv.stride(2); int do_stride0 = d_o.stride(0), do_stride1 = d_o.stride(1), do_stride2 = d_o.stride(2); TORCH_CHECK(q_stride2 == 1); TORCH_CHECK(k_stride2 == 1); TORCH_CHECK(v_stride2 == 1); TORCH_CHECK(o_stride2 == 1); TORCH_CHECK(lse_stride0 == 1); TORCH_CHECK(dq_stride2 == 1); TORCH_CHECK(dk_stride2 == 1); TORCH_CHECK(dv_stride2 == 1); TORCH_CHECK(do_stride2 == 1); StrideQ stride_Q = make_stride(q_stride0, _1{}, make_stride(q_stride1, B == 1 ? 0 : q_stride0*Q)); StrideK stride_K = make_stride(k_stride0, _1{}, make_stride(k_stride1, B == 1 ? 
0 : k_stride0*K)); StrideV stride_V = make_stride(v_stride0, _1{}, make_stride(v_stride1, B == 1 ? 0 : v_stride0*K)); StrideO stride_O = make_stride(o_stride0, _1{}, make_stride(o_stride1, B == 1 ? 0 : o_stride0*Q)); StrideLSE stride_LSE = make_stride(_1{}, make_stride(lse_stride1, B == 1 ? 0 : Q)); StrideDQ stride_dQ = make_stride(dq_stride0, _1{}, make_stride(dq_stride1, B == 1 ? 0 : dq_stride0*Q)); StrideDK stride_dK = make_stride(dk_stride0, _1{}, make_stride(dk_stride1, B == 1 ? 0 : dk_stride0*K)); StrideDV stride_dV = make_stride(dv_stride0, _1{}, make_stride(dv_stride1, B == 1 ? 0 : dv_stride0*K)); StrideDO stride_dO = make_stride(do_stride0, _1{}, make_stride(do_stride1, B == 1 ? 0 : do_stride0*Q)); typename Operation::Arguments arguments{ problem_shape, (static_cast(q.data_ptr())), stride_Q, (static_cast(k.data_ptr())), stride_K, (static_cast(v.data_ptr())), stride_V, (static_cast(o.data_ptr())), stride_O, (static_cast(lse.data_ptr())), stride_LSE, (static_cast(d_o.data_ptr())), stride_dO, (static_cast(dq.data_ptr())), stride_dQ, (static_cast(dk.data_ptr())), stride_dK, (static_cast(dv.data_ptr())), stride_dV, static_cast(softmax_scale), hw_info }; Operation op; uint8_t* workspace_ptr = static_cast(workspace_buffer.data_ptr()); CUTLASS_CHECK(op.can_implement(arguments)); CUTLASS_CHECK(op.initialize(arguments, workspace_ptr)); CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); } }; template void run_fmha_bwd(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { BwdRunner::run(workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, softmax_scale, max_seqlen_q, total_seqlen_kv); } ================================================ FILE: csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu 
================================================ #include "interface.h" #include #include #include #include "common/mask.cuh" #include "common/utils.hpp" #include "fmha_cutlass_fwd_sm100.cuh" template void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { static constexpr bool IsVarlen = std::is_same_v; static constexpr bool IsMla = std::is_same_v; static constexpr bool IsCausalMask = std::is_same_v>; using Option = std::conditional_t, Option>; run_fmha_fwd( workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, softmax_scale, max_seqlen_q, max_seqlen_kv); } void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, int mask_mode_code, float sm_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) { const c10::cuda::OptionalCUDAGuard device_guard(q.device()); CHECK(q.scalar_type() == k.scalar_type()); auto scalar_type_in = q.scalar_type(); auto scalar_type_out = o.scalar_type(); int head_dim_qk = q.size(-1); int head_dim_vo = v.size(-1); MaskMode mask_mode = static_cast(mask_mode_code); if (scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) { using Element = cutlass::bfloat16_t; using ElementOut = cutlass::bfloat16_t; auto apply_config = [&](auto fn) { if (mask_mode == MaskMode::kCausal) { if (is_varlen) { fn(CausalMask{}, cute::true_type{}, Element{}, ElementOut{}); } else { fn(CausalMask{}, cute::false_type{}, Element{}, ElementOut{}); } } else { if (is_varlen) { fn(ResidualMask{}, cute::true_type{}, Element{}, 
ElementOut{}); } else { fn(ResidualMask{}, cute::false_type{}, Element{}, ElementOut{}); } } }; apply_config([&](auto mask, auto varlen, auto in, auto out) { if (head_dim_qk == 192 && head_dim_vo == 128) { call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, max_seqlen_q, max_seqlen_kv); } else if (head_dim_qk == 128 && head_dim_vo == 128) { call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, max_seqlen_q, max_seqlen_kv); } else { std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl; } }); } else { FLASH_MLA_ASSERT(false); } } ================================================ FILE: csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh ================================================ #pragma once #include "collective/fmha_fusion.hpp" #include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp" #include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp" #include "collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp" #include "cutlass/cutlass.h" #include "cutlass/kernel_hardware_info.h" #include "device/fmha.hpp" #include "kernel/fmha_causal_tile_scheduler.hpp" #include "kernel/fmha_options.hpp" #include "kernel/fmha_tile_scheduler.hpp" #include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp" #include #include #include using namespace cute; using namespace cutlass::fmha::collective; using namespace cutlass::fmha::kernel; using namespace cutlass::fmha::device; struct FmhaOptions { int b = 1; int h = 1; int h_k = 1; int q = 256; int k = 256; int d = 128; }; struct MlaOptions { int b = 1; int h = 1; int h_k = 1; int q = 256; int k = 256; int dl = 128; // headdim latent int dr = 64; // headdim rope }; template struct FwdRunner { using Element = Element_; using ElementAccumulatorQK = float; using 
ElementAccumulatorPV = float; using ElementOut = ElementOut_; using HeadDimLatent = _128; using HeadDim = Shape; using TileShapeMla = Shape<_256, _128, HeadDim>; using TileShapeFmha = Shape<_256, _128, _128>; using TileShape = std::conditional_t; using ProblemShapeRegular = std::conditional_t< kIsMla, cute::tuple, cute::tuple, int>>, cute::tuple, int>>>; using ProblemShapeVarlen = std::conditional_t, cute::tuple, int>>, cute::tuple, int>>>; using ProblemShapeType = std::conditional_t; using StrideQ = cute::tuple, int>>; using StrideK = cute::tuple, int>>; using StrideV = StrideK; using StrideO = StrideQ; using StrideLSE = cute::tuple<_1, cute::tuple, int>>; static constexpr bool kIsPersistent = find_option_t::value; using TileScheduler = std::conditional_t< kIsPersistent, std::conditional_t> || std::is_same_v>, cutlass::fmha::kernel::CausalPersistentTileScheduler, cutlass::fmha::kernel::PersistentTileScheduler>, std::conditional_t>; static constexpr bool IsOrderLoadEpilogue = kIsPersistent && (sizeof(Element) == sizeof(ElementOut)); using OrderLoadEpilogue = std::conditional_t; using MainloopMla = cutlass::fmha::collective::Sm100MlaFwdMainloopTmaWarpspecialized< Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeMla, StrideQ, StrideK, StrideV, ActiveMask, Shape<_2, _1, _1>, OrderLoadEpilogue>; using OperationMla = cutlass::fmha::device::FMHA, TileScheduler, cutlass::fmha::kernel::Sm100MlaFwdCtxKernelWarpspecializedSchedule>>; using MainloopFmha = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized< Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeFmha, StrideQ, StrideK, StrideV, ActiveMask>; using OperationFmha = cutlass::fmha::device::FMHA, TileScheduler>>; using Mainloop = std::conditional_t; using Operation = std::conditional_t; // // Data members // /// Initialization StrideQ stride_Q; StrideK stride_K; StrideV stride_V; StrideO stride_O; StrideLSE stride_LSE; template auto initialize_varlen(const ProblemShape 
&problem_size, int max_seqlen_q, int max_seqlen_kv, int total_seqlen_q, int total_seqlen_kv) { int num_batches = get<3, 1>(problem_size); ProblemShape problem_size_for_init = problem_size; get<3, 1>(problem_size_for_init) = 1; get<0>(problem_size_for_init) = total_seqlen_q; get<1>(problem_size_for_init) = total_seqlen_kv; ProblemShapeType problem_size_for_launch; get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q}; get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv}; get<2>(problem_size_for_launch) = get<2>(problem_size); get<3>(problem_size_for_launch) = get<3>(problem_size); return cute::make_tuple(problem_size_for_init, problem_size_for_launch); } template static constexpr auto get_problem_shape(const Options &options) { int h_r = options.h / options.h_k; if constexpr (std::is_same_v) { return cute::make_tuple(options.q, options.k, cute::make_tuple(options.dl, options.dr), cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); } else { return cute::make_tuple(options.q, options.k, options.d, cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); } } template ProblemShapeType initialize(const Options &options, int max_seqlen_q, int max_seqlen_kv, int total_seqlen_q, int total_seqlen_kv, void *cumulative_length_q, void *cumulative_length_kv) { assert(options.h % options.h_k == 0); auto problem_shape_in = get_problem_shape(options); ProblemShapeType problem_shape; decltype(problem_shape_in) problem_size; if constexpr (kIsVarlen) { auto [problem_shape_init, problem_shape_launch] = initialize_varlen( problem_shape_in, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv); problem_shape = problem_shape_launch; problem_size = problem_shape_init; } else { problem_size = problem_shape_in; problem_shape = problem_shape_in; } auto get_head_dimension = [&]() { if constexpr (rank_v(problem_shape))> == 2) { return cute::make_tuple(size<2, 0>(problem_shape) + size<2, 
1>(problem_shape), size<2, 0>(problem_shape)); } else { return cute::make_tuple(size<2>(problem_size), size<2>(problem_size)); } }; if constexpr (kIsVarlen) { get<0>(problem_shape).cumulative_length = static_cast(cumulative_length_q); get<1>(problem_shape).cumulative_length = static_cast(cumulative_length_kv); } return problem_shape; } auto get_arguments(const ProblemShapeType &problem_shape, const cutlass::KernelHardwareInfo &hw_info, float scale_softmax, void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr, void *cumulative_length_q, void *cumulative_length_kv) { auto problem_shape_ = problem_shape; typename Operation::Arguments arguments{ problem_shape_, {static_cast(q_ptr), stride_Q, static_cast(k_ptr), stride_K, static_cast(v_ptr), stride_V, scale_softmax}, {static_cast(o_ptr), stride_O, static_cast(lse_ptr), stride_LSE}, hw_info}; return arguments; } template void run(const Options &options, const cutlass::KernelHardwareInfo &hw_info, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, float scale_softmax, at::Tensor workspace, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, int max_seqlen_q, int max_seqlen_kv) { int total_seqlen_q = q.size(0); int total_seqlen_kv = k.size(0); ProblemShapeType problem_shape = initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv, cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); int SQ = size<0>(problem_shape); int SK = size<1>(problem_shape); int B = size<3, 1>(problem_shape); int H = size<3, 0>(problem_shape); int H_K = size<3, 0, 1>(problem_shape); int H_Q = size<3, 0, 0>(problem_shape); int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); int lse_stride0 = lse.stride(0), 
lse_stride1 = lse.stride(1); TORCH_CHECK(q_stride2 == 1); TORCH_CHECK(k_stride2 == 1); TORCH_CHECK(v_stride2 == 1); TORCH_CHECK(o_stride2 == 1); TORCH_CHECK(lse_stride0 == 1); stride_Q = make_stride(q_stride0, _1{}, make_stride(make_stride(q_stride1, H_Q * q_stride1), SQ * q_stride0)); stride_O = make_stride(o_stride0, _1{}, make_stride(make_stride(o_stride1, H_Q * o_stride1), SQ * o_stride0)); stride_K = make_stride(k_stride0, _1{}, make_stride(make_stride(_0{}, k_stride1), SK * k_stride0)); stride_V = make_stride(v_stride0, _1{}, make_stride(make_stride(_0{}, v_stride1), SK * v_stride0)); stride_LSE = make_stride(_1{}, make_stride(make_stride(lse_stride1, lse_stride1 * H_Q), SQ)); if constexpr (kIsVarlen) { get<2, 1>(stride_Q) = 0; get<2, 1>(stride_K) = 0; get<2, 1>(stride_V) = 0; get<2, 1>(stride_O) = 0; get<1, 1>(stride_LSE) = 0; } typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, scale_softmax, q.data_ptr(), k.data_ptr(), v.data_ptr(), o.data_ptr(), lse.data_ptr(), cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); Operation op; // size_t workspace_size = 0; // workspace_size = Operation::get_workspace_size(arguments); // todo: if use workspace, need check workspace size first. // we don't use workspace in current version. 
CUTLASS_CHECK(op.can_implement(arguments)); CUTLASS_CHECK(op.initialize(arguments, nullptr)); CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); } }; template void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) { const at::cuda::CUDAGuard device_guard{(char)q.get_device()}; const int device_id = q.get_device(); cutlass::KernelHardwareInfo hw_info; hw_info.device_id = device_id; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); auto get_options = [&]() { if constexpr (kIsMla) { MlaOptions options; options.b = cumulative_seqlen_q.size(0) - 1; options.h = q.size(1); options.h_k = k.size(1); options.q = q.size(0) / options.b; options.k = k.size(0) / options.b; options.dl = v.size(-1); options.dr = q.size(-1) - v.size(-1); return options; } else { FmhaOptions options; options.b = cumulative_seqlen_q.size(0) - 1; options.h = q.size(1); options.h_k = k.size(1); options.q = q.size(0) / options.b; options.k = k.size(0) / options.b; options.d = q.size(-1); return options; } }; auto options = get_options(); if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 && (std::is_same_v> || std::is_same_v>)) { FwdRunner runner; runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); } else { FwdRunner runner; runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); } } ================================================ FILE: csrc/sm100/prefill/dense/interface.h ================================================ #pragma once #include void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor 
cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); ================================================ FILE: csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"

namespace cutlass::fmha::kernel {

////////////////////////////////////////////////////////////////////////////////

// Swizzle Q tile and H tile to improve L2 cache hit rate,
// and launch the longest main loop first to keep most SMs busy.
//
// Non-persistent scheduler: each CTA handles exactly one work item (operator++ marks
// it done). Grid = (size<3,0>(problem_size), q blocks rounded up to the cluster M extent,
// size<3,1>(problem_size)). The first tile_max_q q-block rows of the (x, y) plane are
// remapped through TileQ x TileH swizzle tiles; any remaining rows use a direct mapping.
// In both cases the q-block index is reversed, so lower linear CTA indices receive the
// highest q-block indices.
struct CausalIndividualTileScheduler {

  static constexpr int TileQ = 16;  // q-block rows per swizzle tile
  static constexpr int TileH = 8;   // gridDim.x columns (heads) per swizzle tile
  static constexpr int TileSize = TileQ * TileH;

  struct Params {
    dim3 grid;
    int tile_max_q;                // rows covered by whole swizzle tiles (a multiple of TileQ)
    FastDivmod divmod_tile_col;    // divisor: number of swizzle tiles along gridDim.x
    FastDivmod divmod_tile_size;   // divisor: TileSize
    FastDivmod divmod_tile_head;   // divisor: TileH
  };

  bool valid_ = true;
  Params params;

  CUTLASS_DEVICE
  CausalIndividualTileScheduler(Params const& params) : params(params) {}

  template
  static Params to_underlying_arguments(
      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, TileShape const& tile_shape) {
    using namespace cute;
    dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));
    // gridDim.x must be a multiple of TileH: the swizzle in get_block_coord() only covers
    // whole TileH-wide column tiles. This is not checked here; callers must guarantee it.
    const int tile_col_count = grid.x / TileH;
    // Largest multiple of TileQ not exceeding the number of q-block rows.
    const int tile_max_q = grid.y / TileQ * TileQ;
    return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH};
  }

  static dim3 get_grid_shape(Params const& params) {
    return params.grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return valid_;
  }

  // Returns (q block, _0, (head, batch)) for this CTA.
  // NOTE: FastDivmod is called as (quotient, remainder, dividend), following the CUTLASS
  // convention in cutlass/fast_math.h.
  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    // Linear CTA index over the (x, y) plane; the batch index comes directly from blockIdx.z.
    const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;
    // tile_idx = which swizzle tile; tile_tail = position inside that tile.
    int tile_idx, tile_tail;
    params.divmod_tile_size(tile_idx, tile_tail, block_idx);
    // Swizzle tiles are laid out row-major with tile_col_count tiles per row.
    int tile_row_idx, tile_col_idx;
    params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx);
    // Within a tile, consecutive CTAs step through TileH columns before the next q row.
    int row_offset_in_tail, col_offset_in_tail;
    params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail);
    const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;
    const int col_idx = tile_col_idx * TileH + col_offset_in_tail;
    // last q tile launch first
    if(blockIdx.y >= params.tile_max_q) {
      // Rows past the last whole swizzle tile: direct (unswizzled) mapping, still reversed in q.
      return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));
    }
    return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));
  }

  // One work item per CTA: advancing simply invalidates the scheduler.
  CUTLASS_DEVICE
  CausalIndividualTileScheduler& operator++() {
    valid_ = false;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

// Launch order: H Q B
//
// Persistent scheduler: launches min(num_blocks, sm_count) CTAs (get_grid_shape). Each
// CTA starts at work item blockIdx.x and advances by gridDim.x until num_blocks is
// reached. The linear work index is decoded with the head index varying fastest, then
// the q block, then the batch.
struct CausalPersistentTileScheduler {

  struct Params {
    int num_blocks;           // total work items = q blocks * size<3,0> * size<3,1>
    FastDivmod divmod_h;
    FastDivmod divmod_m_block;
    FastDivmod divmod_b;
    KernelHardwareInfo hw_info;
  };

  int block_idx = 0;  // current linear work item of this CTA
  Params params;

  CUTLASS_DEVICE
  CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}

  template
  static Params to_underlying_arguments(
      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, TileShape const& tile_shape) {
    using namespace cute;
    // Get SM count if needed, otherwise use user supplied SM count
    int sm_count = hw_info.sm_count;
    if (sm_count <= 0) {
      CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    }

    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
    hw_info.sm_count = sm_count;

    // Number of q blocks, rounded up to a multiple of the cluster's M extent.
    int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
    int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);

    // max(1, ...) presumably keeps the FastDivmod divisor non-zero for an empty q dimension.
    return Params {
      num_blocks,
      { size<3,0>(problem_size) }, { max(1, num_m_blocks) }, { size<3,1>(problem_size) },
      hw_info
    };
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
    return grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return block_idx < params.num_blocks;
  }

  // Decodes block_idx into (q block, _0, (head, batch)).
  // FastDivmod is called as (quotient, remainder, dividend) (CUTLASS convention), so each
  // step peels off one index as the remainder and keeps the quotient in block_decode.
  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    int block_decode = block_idx;
    int m_block, bidb, bidh;
    params.divmod_h(block_decode, bidh, block_decode);
    params.divmod_m_block(block_decode, m_block, block_decode);
    params.divmod_b(block_decode, bidb, block_decode);
    return make_coord(m_block, _0{}, make_coord(bidh, bidb));
  }

  // Advance to this CTA's next work item.
  CUTLASS_DEVICE
  CausalPersistentTileScheduler& operator++() {
    block_idx += gridDim.x;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::kernel
================================================
FILE: csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp
================================================
/*************************************************************************************************** * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2.
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cute/layout.hpp" #include // for KERUTILS_ENABLE_SM100A namespace cutlass::fmha::kernel { using namespace cute; template struct FmhaKernelBwdConvert { struct Arguments { ProblemShape problem_shape; const ElementAcc* ptr_src_dQ; tuple> stride_src_dQ; const ElementAcc* ptr_src_dK; tuple> stride_src_dK; const ElementAcc* ptr_src_dV; tuple> stride_src_dV; Element* ptr_dest_dQ; tuple> stride_dest_dQ; Element* ptr_dest_dK; tuple> stride_dest_dK; Element* ptr_dest_dV; tuple> stride_dest_dV; ElementAcc scale = 1.0; }; using Params = Arguments; using ClusterShape = Shape<_1, _1, _1>; static constexpr int SharedStorageSize = 0; static const int MinBlocksPerMultiprocessor = 1; static const int MaxThreadsPerBlock = 128; using ArchTag = cutlass::arch::Sm90; static const int kBlockSeq = 8; static size_t get_workspace_size(Arguments const& args) { return 0; } static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { return cutlass::Status::kSuccess; } static const int kNumThreadsD = 16; static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; static const int kElementsPerLoad = 4; static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; static bool can_implement(Arguments const& args) { return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); return grid; } static dim3 get_block_shape() { dim3 block(kNumThreadsD, kNumThreadsSeq, 1); return block; } static Params to_underlying_arguments(Arguments const& args, void* workspace) { return args; } template CUTLASS_DEVICE void copy(Params const& params, const 
ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; int seqlen = count; if constexpr (is_variable_length_v) { int offset = count.cumulative_length[blockIdx.y]; ptr_dest_bh += offset * get<0>(stride_dest); seqlen = count.cumulative_length[blockIdx.y + 1] - offset; } for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { int idx_s = idx_s_t + kBlockSeq * blockIdx.z; if (idx_s >= seqlen) continue; auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) { ElementAcc value_src[kElementsPerLoad]; Element value_dest[kElementsPerLoad]; using VecSrc = uint_bit_t * kElementsPerLoad>; using VecDest = uint_bit_t * kElementsPerLoad>; *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); for (int v = 0; v < kElementsPerLoad; v++) { value_dest[v] = static_cast(params.scale * value_src[v]); } *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); } } } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { #if defined(KERUTILS_ENABLE_SM100A) if (params.ptr_src_dQ != nullptr) { copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); } if (params.ptr_src_dK != nullptr) { copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape)); } if (params.ptr_src_dV != nullptr) { copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, 
get<1>(params.problem_shape), get<3>(params.problem_shape)); } #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); } #endif } }; } // namespace cutlass::fmha::kernel ================================================ FILE: csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cute/layout.hpp" #include // for KERUTILS_ENABLE_SM100A namespace cutlass::fmha::kernel { using namespace cute; template struct FmhaKernelBwdSumOdO { struct Arguments { ProblemShape problem_shape; const Element* ptr_O; cute::tuple> stride_O; const Element* ptr_dO; cute::tuple> stride_dO; ElementAcc* ptr_sum_OdO; cute::tuple> stride_sum_OdO; const ElementAcc* ptr_lse = nullptr; cute::tuple> stride_lse; ElementAcc* ptr_scaled_lse = nullptr; cute::tuple> stride_scaled_lse; ElementAcc sum_odo_scale = 1.0; ElementAcc lse_scale = 1.0; }; using Params = Arguments; using ClusterShape = Shape<_1, _1, _1>; static constexpr int SharedStorageSize = 0; static const int MinBlocksPerMultiprocessor = 1; static const int MaxThreadsPerBlock = 128; using ArchTag = cutlass::arch::Sm100; static size_t get_workspace_size(Arguments const& args) { return 0; } static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { return cutlass::Status::kSuccess; } static const int kBlockQ = 16; static const int kNumThreadsD = 8; static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; static const int kElementsPerLoad = 2; static const int kIterationsQ = kBlockQ / kNumThreadsQ; static bool can_implement(Arguments const& args) { return get<2>(args.problem_shape) % kElementsPerLoad == 0 && 
get<3>(args.problem_shape) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape)); return grid; } static dim3 get_block_shape() { dim3 block(kNumThreadsD, kNumThreadsQ, 1); return block; } static Params to_underlying_arguments(Arguments const& args, void* workspace) { return args; } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { #if defined(KERUTILS_ENABLE_SM100A) auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); auto problem_q = get<0>(params.problem_shape); int seqlen_q = problem_q; if constexpr (is_variable_length_v) { int offset = problem_q.cumulative_length[blockIdx.z]; ptr_O_bh += offset * get<0>(params.stride_O); ptr_dO_bh += offset * get<0>(params.stride_dO); ptr_lse_bh += offset * get<0>(params.stride_lse); seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; } CUTLASS_PRAGMA_UNROLL for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { int idx_q = idx_q_t + kBlockQ * blockIdx.x; if (idx_q >= seqlen_q) continue; ElementAcc acc = 0; auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO); auto ptr_lse_bhq = ptr_lse_bh + idx_q * 
get<0>(params.stride_lse); auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { Element value_O[kElementsPerLoad]; Element value_dO[kElementsPerLoad]; using Vec = uint_bit_t * kElementsPerLoad>; *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); for (int v = 0; v < kElementsPerLoad; v++) { acc += ElementAcc(value_O[v]) * ElementAcc(value_dO[v]); } } for (int i = 1; i < kNumThreadsD; i *= 2) { acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); } if (threadIdx.x == 0) { *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; if (params.ptr_scaled_lse) { *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; } } } #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); } #endif } }; } // namespace cutlass::fmha::kernel ================================================ FILE: csrc/sm100/prefill/dense/kernel/fmha_options.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. 
Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"

namespace cutlass::fmha::kernel {

// Compile-time lookup of a kernel option by tag.
//
// NOTE(review): every template parameter list and every specialization
// argument list in this file was stripped when the file was extracted, so
// each `template` keyword below stands alone and the declarations are not
// valid C++ as shown. Restore them from the repository. The descriptions
// below are based only on the bodies that survived.

// Primary template: declared only. The two definitions that follow are
// presumably its partial specializations (their argument lists were lost).
template struct find_option;

// Terminal case (presumably the empty option list): yields `Default`.
template struct find_option {
  using option_value = Default;
};

// Recursive case. If `Option::tag == kTag`, inherit from `Option`, so
// option_value is that option's value. Otherwise inherit from a further
// find_option instantiation, presumably over the remaining options.
template struct find_option : std::conditional_t< Option::tag == kTag, Option, find_option > {};

// Shorthand for the option_value selected by find_option.
template using find_option_t = typename find_option::option_value;

// Tags naming the individual kernel options. Their meaning is defined by the
// kernels that look them up, which are not visible here.
enum class Tag {
  kIsPersistent,
  kNumMmaWarpGroups,
  kLoadsQSeparately,

  kIsMainloopLocked,
  kIsEpilogueLocked,

  kStagesQ,
  kStagesKV,

  kEpilogueKind,

  kBlocksPerSM,
  kClusterM,

  kAccQK
};

// A (tag, value) pair. `tag` is what find_option compares against kTag, and
// `option_value` is what it yields on a match.
template struct Option {
  static constexpr auto tag = kTag;
  using option_value = Value;
};

} // namespace cutlass::fmha::kernel
================================================ FILE: csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp ================================================
/*************************************************************************************************** * Copyright (c) 2024 - 2025
NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/kernel_hardware_info.h" namespace cutlass::fmha::kernel { //////////////////////////////////////////////////////////////////////////////// struct IndividualTileScheduler { struct Params { dim3 grid; }; bool valid_ = true; CUTLASS_DEVICE IndividualTileScheduler(Params const&) {} template static Params to_underlying_arguments( ProblemSize const& problem_size, KernelHardwareInfo hw_info, ClusterShape const& cluster_shape, TileShape const& tile_shape) { using namespace cute; dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size)); return Params{ grid }; } static dim3 get_grid_shape(Params const& params) { return params.grid; } CUTLASS_DEVICE bool is_valid() { return valid_; } CUTLASS_DEVICE auto get_block_coord() { using namespace cute; return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); } CUTLASS_DEVICE IndividualTileScheduler& operator++() { valid_ = false; return *this; } }; //////////////////////////////////////////////////////////////////////////////// struct PersistentTileScheduler { struct Params { int num_blocks; FastDivmod divmod_m_block; FastDivmod divmod_h; FastDivmod divmod_b; KernelHardwareInfo hw_info; }; int block_idx = 0; Params params; CUTLASS_DEVICE PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} template static Params to_underlying_arguments( ProblemSize const& problem_size, KernelHardwareInfo hw_info, ClusterShape const& cluster_shape, TileShape const& tile_shape) { using namespace cute; // Get SM count if needed, otherwise use user supplied SM count int sm_count = hw_info.sm_count; if (sm_count <= 0) { CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" " For optimal performance, 
populate the arguments KernelHardwareInfo struct with the SM count."); sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); } CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); hw_info.sm_count = sm_count; int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); return Params { num_blocks, { max(1, num_m_blocks)}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, hw_info }; } static dim3 get_grid_shape(Params const& params) { dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); return grid; } CUTLASS_DEVICE bool is_valid() { return block_idx < params.num_blocks; } CUTLASS_DEVICE auto get_block_coord() { using namespace cute; int block_decode = block_idx; int m_block, bidb, bidh; params.divmod_m_block(block_decode, m_block, block_decode); params.divmod_b(block_decode, bidb, block_decode); params.divmod_h(block_decode, bidh, block_decode); return make_coord(m_block, _0{}, make_coord(bidh, bidb)); } CUTLASS_DEVICE PersistentTileScheduler& operator++() { block_idx += gridDim.x; return *this; } }; //////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::fmha::kernel ================================================ FILE: csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. 
Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cute/tensor.hpp" #include "cute/arch/simd_sm100.hpp" #include "cutlass/arch/arch.h" #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" #include // for KERUTILS_ENABLE_SM100A #include "../collective/fmha_common.hpp" #include namespace cutlass::fmha::kernel { using namespace cutlass::fmha::collective; using namespace cute; template< class ProblemShape, class Element, class ElementAcc, class TileShape, class Mask > struct Sm100FmhaBwdKernelTmaWarpSpecialized { using TileShapeQ = decltype(get<0>(TileShape{})); static_assert(std::is_same_v, "tile shape K must be 128"); using TileShapeK = decltype(get<1>(TileShape{})); static_assert(std::is_same_v, "tile shape K must be 128"); using TileShapeDQK = decltype(get<2>(TileShape{})); using TileShapeDVO = decltype(get<2>(TileShape{})); using TmemAllocator = cute::TMEM::Allocator1Sm; struct TmemAllocation { static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); static constexpr uint32_t kP = kS; static constexpr uint32_t kTotal = kS + TileShapeQ{}; }; static_assert( static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem" ); enum class WarpRole { Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 }; static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; static constexpr int kNumComputeWarps = 8; static constexpr int kNumReduceWarps = 4; CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { return 
static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); } struct RegisterAllocation { static constexpr int kWarpgroup0 = 160-8; static constexpr int kWarpgroup1 = 128; static constexpr int kWarpgroup2 = 96; static constexpr int kReduce = kWarpgroup0; static constexpr int kCompute = kWarpgroup1; static constexpr int kMma = kWarpgroup2; static constexpr int kEmpty = kWarpgroup2; static constexpr int kLoad = kWarpgroup2; static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); }; using ArchTag = cutlass::arch::Sm100; using ClusterShape = Shape<_1, _1, _1>; using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; static constexpr int MinBlocksPerMultiprocessor = 1; static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; static constexpr int Alignment = 128 / sizeof_bits_v; static constexpr int kStages = 2; using TensorStrideContiguousK = Stride>; using TensorStrideContiguousMN = Stride<_1, int, Stride>; // compute S using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, TensorStrideContiguousK, Alignment, Element, TensorStrideContiguousK, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeKQ = typename CollectiveMmaKQ::TileShape; using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma; // compute dP using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, TensorStrideContiguousK, Alignment, Element, TensorStrideContiguousK, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeVDO = typename CollectiveMmaVDO::TileShape; using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; // compute dV using CollectiveMmaPDO = typename 
cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // needs to match ordering of S calculation Element, TensorStrideContiguousK, Alignment, Element, TensorStrideContiguousMN, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapePDO = typename CollectiveMmaPDO::TileShape; using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); // compute dK using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // somewhat arbitrary since we dump to smem, need to agree with the next one Element, TensorStrideContiguousK , Alignment, Element, TensorStrideContiguousMN, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; // compute dQ using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // somewhat arbitrary since we dump to smem, need to agree with the previous one Element, TensorStrideContiguousMN, Alignment, Element, TensorStrideContiguousMN, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeDSK = typename CollectiveMmaDSK::TileShape; using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; // pipelines are named Pipeline static constexpr int kStagesComputeSmem = 1; using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; using PipelineLoadComputeLSE = PipelineAsync<1>; using PipelineLoadComputeSumOdO = PipelineAsync<1>; using PipelineMmaComputeS = PipelineUmmaAsync<1>; using PipelineMmaComputeDP = PipelineUmmaAsync<1>; using PipelineMmaReduceDQ = 
PipelineUmmaAsync<1>; using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; static constexpr int kStagesReduceTmaStore = 2; using PipelineReduceTmaStore = PipelineTmaStore; struct PipelineStorage { alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; }; template static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { return composition(layout, make_tuple(_, _, _, make_layout(stages))); } using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); using SmemLayoutLSE = Layout>; using SmemLayoutSumOdO = Layout>; using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); using SmemLayoutDOT = 
decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); using TileShapeDQ = _32; using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ >()); using SmemShapeDQ = Shape>; using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); struct TensorStorage { union { alignas(2048) cute::array> smem_k; alignas(2048) cute::array> smem_k_t; }; alignas(2048) cute::array> smem_v; union { alignas(2048) cute::array> smem_q; alignas(2048) cute::array> smem_q_t; }; union { alignas(2048) cute::array> smem_do; alignas(2048) cute::array> smem_do_t; }; union { alignas(2048) cute::array> smem_ds; alignas(2048) cute::array> smem_ds_t; }; alignas(1024) cute::array> smem_dq; alignas(16) cute::array> smem_lse; alignas(16) cute::array> smem_sum_odo; }; static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); struct SharedStorage { TensorStorage tensors; PipelineStorage pipelines; uint32_t tmem_base_ptr; }; // this is tight enough that it won't work with sizeof due to padding for alignment static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); using TensorStride = TensorStrideContiguousK; // S D (H B) using RowTensorStride = Stride<_1, Stride>; // S (H B) struct MainloopArguments { const Element* ptr_q; TensorStride stride_q; const Element* ptr_k; TensorStride stride_k; 
const Element* ptr_v; TensorStride stride_v; const Element* ptr_do; TensorStride stride_do; const ElementAcc* ptr_lse; RowTensorStride stride_lse; const ElementAcc* ptr_sum_odo; RowTensorStride stride_sum_odo; ElementAcc* ptr_dq_acc; TensorStride stride_dq_acc; ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); }; using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), SmemLayoutDQ{}(_, _, _0{}) )); struct MainloopParams { TMA_K tma_load_k; TMA_V tma_load_v; TMA_Q tma_load_q; TMA_DO tma_load_do; TMA_DQ tma_red_dq; }; struct EpilogueArguments { Element* ptr_dk; TensorStride stride_dk; Element* ptr_dv; TensorStride stride_dv; }; struct Arguments { ProblemShape problem_shape; MainloopArguments mainloop; EpilogueArguments epilogue; KernelHardwareInfo hw_info; }; struct Params { ProblemShape problem_shape; MainloopArguments mainloop; MainloopParams mainloop_params; EpilogueArguments epilogue; KernelHardwareInfo hw_info; }; static bool can_implement(Arguments const& args) { auto [Q, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) { return false; } if (D % Alignment != 0 || D_VO % Alignment != 0) { return false; } return true; } static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { return Status::kSuccess; } static Params to_underlying_arguments(Arguments const& args, void*) { auto [Q_, K_, D, D_VO, HB] = args.problem_shape; int Q = Q_; int K = K_; if constexpr (is_variable_length_v) { Q = Q_.total_length; } if constexpr (is_variable_length_v) { K = K_.total_length; } auto params_kq = CollectiveMmaKQ::to_underlying_arguments( make_shape(K, Q, D, HB), 
typename CollectiveMmaKQ::Arguments { args.mainloop.ptr_k, args.mainloop.stride_k, args.mainloop.ptr_q, args.mainloop.stride_q, }, /*workspace=*/nullptr); auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( make_shape(K, Q, D_VO, HB), typename CollectiveMmaVDO::Arguments { args.mainloop.ptr_v, args.mainloop.stride_v, args.mainloop.ptr_do, args.mainloop.stride_do, }, /*workspace=*/nullptr); TMA_DQ tma_red_dq = make_tma_copy( SM90_TMA_REDUCE_ADD{}, make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), SmemLayoutDQ{}(_, _, _0{}) ); return Params{ args.problem_shape, args.mainloop, MainloopParams{ params_kq.tma_load_a, params_vdo.tma_load_a, params_kq.tma_load_b, params_vdo.tma_load_b, tma_red_dq }, args.epilogue, args.hw_info }; } template static CUTLASS_DEVICE auto quantize(T const& input) { constexpr int AlignmentS = 4; auto output = make_tensor(shape(input)); auto input_vec = recast>(input); auto output_vec = recast>(output); cutlass::NumericArrayConverter epilogue_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(input_vec); i++) { output_vec(i) = epilogue_op(input_vec(i)); } return output; } template CUTLASS_DEVICE void load( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, TensorStorage& shared_tensors, PipelineLoadMmaQ& pipeline_load_mma_q, typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, PipelineLoadMmaDO& pipeline_load_mma_do, typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, PipelineLoadComputeLSE& pipeline_load_compute_lse, typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { auto [Q, K, D, D_VO, HB] = 
problem_shape; using X = Underscore; uint16_t mcast_mask = 0; auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); auto tSTgK = cta_mma_kq.partition_A(gK); auto tSTgQ = cta_mma_kq.partition_B(gQ); auto tDPTgV = cta_mma_vdo.partition_A(gV); auto tDPTgDO = cta_mma_vdo.partition_B(gDO); auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); auto [tKgK_mkl, tKsK] = tma_partition( mainloop_params.tma_load_k, _0{}, make_layout(_1{}), group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); auto [tQgQ_mkl, tQsQ] = tma_partition( mainloop_params.tma_load_q, _0{}, make_layout(_1{}), group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); auto [tVgV_mkl, tVsV] = tma_partition( mainloop_params.tma_load_v, _0{}, make_layout(_1{}), group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); auto [tDOgDO_mkl, tDOsDO] = tma_partition( 
mainloop_params.tma_load_do, _0{}, make_layout(_1{}), group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); // set up lse and sum_odo auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); // load K if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), tKsK(_, _0{}) ); } // load Q if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), tQsQ(_, pipeline_load_mma_q_producer_state.index()) ); } ++pipeline_load_mma_q_producer_state; pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); // load LSE // 32 threads loading 128 values of 32b each // so 4*32b=128b int thread_idx = threadIdx.x % NumThreadsPerWarp; int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); for (int i = 0; i < 4; i++) { cutlass::arch::cp_async_zfill<4>( shared_tensors.smem_lse.begin() + smem_idx + i, &mLSE(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, 
kTransactionsBytesLoadV); // load V if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), tVsV(_, _0{}) ); } // load dO if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), tDOsDO(_, pipeline_load_mma_do_producer_state.index()) ); } ++pipeline_load_mma_do_producer_state; pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); // load sum_OdO smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); for (int i = 0; i < 4; i++) { cutlass::arch::cp_async_zfill<4>( shared_tensors.smem_sum_odo.begin() + smem_idx + i, &mSumOdO(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; iter_count -= 1; iter_index += 1; while (iter_count > 0) { pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); // load Q if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), tQsQ(_, pipeline_load_mma_q_producer_state.index()) ); } ++pipeline_load_mma_q_producer_state; pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); // load LSE smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; for (int i = 0; i < 4; i++) { cutlass::arch::cp_async_zfill<4>( 
shared_tensors.smem_lse.begin() + smem_idx + i, &mLSE(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); // load dO if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), tDOsDO(_, pipeline_load_mma_do_producer_state.index()) ); } ++pipeline_load_mma_do_producer_state; pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); // load sum_OdO smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; for (int i = 0; i < 4; i++) { cutlass::arch::cp_async_zfill<4>( shared_tensors.smem_sum_odo.begin() + smem_idx + i, &mSumOdO(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; iter_count -= 1; iter_index += 1; } } template CUTLASS_DEVICE void mma( BlkCoord const& blk_coord, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, TensorStorage& shared_tensors, PipelineLoadMmaQ& pipeline_load_mma_q, typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, PipelineLoadMmaDO& pipeline_load_mma_do, typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, PipelineMmaComputeS& pipeline_mma_compute_s, typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, PipelineMmaComputeDP& pipeline_mma_compute_dp, typename 
PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, PipelineMmaReduceDQ& pipeline_mma_reduce_dq, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, PipelineComputeMmaP& pipeline_compute_mma_p, typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, PipelineComputeMmaDS& pipeline_compute_mma_ds, typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, 
_0{}); tDVrP.data() = TmemAllocation::kP; Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); TiledMmaKQ tiled_mma_kq; TiledMmaVDO tiled_mma_vdo; TiledMmaDSK tiled_mma_dsk; TiledMmaDSQ tiled_mma_dsq; TiledMmaPDO tiled_mma_pdo; tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); tSTtST.data() = TmemAllocation::kS; Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); tDPTtDPT.data() = TmemAllocation::kDP; Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); tDQtDQ.data() = TmemAllocation::kDQ; Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); tDKtDK.data() = TmemAllocation::kDK; Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); tDVtDV.data() = TmemAllocation::kDV; auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); // S = Q*K tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { cute::gemm(tiled_mma_kq, tSTrK(_,_,k_block,_0{}), tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), tSTtST); tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; } ++pipeline_load_mma_q_consumer_state; pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); ++pipeline_mma_compute_s_producer_state; pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); // dP = dO*V tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDPTrV); 
++k_block) { cute::gemm(tiled_mma_vdo, tDPTrV(_,_,k_block,_0{}), tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDPTtDPT); tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); ++pipeline_mma_compute_dp_producer_state; pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); // dV = P*dO CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { cute::gemm(tiled_mma_pdo, tDVrP(_,_,k_block), tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDVtDV); tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; } pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); ++pipeline_compute_mma_p_consumer_state; pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); ++pipeline_load_mma_do_consumer_state; iter_count -= 1; // in tmem, S & P overlap // and dP and dQ overlap // so we need to acquire dQ and dP at the same time while (iter_count > 0) { pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); // S = Q*K tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { cute::gemm(tiled_mma_kq, tSTrK(_,_,k_block,_0{}), tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), tSTtST); tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; } ++pipeline_load_mma_q_consumer_state; pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); ++pipeline_mma_compute_s_producer_state; pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); // we need to acquire dP here, because tmem dQ == tmem dP pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); // dQ = dS*K tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block 
< size<2>(tDQrDS); ++k_block) { cute::gemm(tiled_mma_dsk, tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDQrKT(_,_,k_block,_0{}), tDQtDQ); tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); ++pipeline_mma_reduce_dq_producer_state; // dK = dS*Q CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { cute::gemm(tiled_mma_dsq, tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), tDKtDK); tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; } pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); ++pipeline_load_mma_q_release_state; pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); ++pipeline_compute_mma_ds_consumer_state; // we grab dq here, because in tmem dq == dp pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); // dP = dO*V tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { cute::gemm(tiled_mma_vdo, tDPTrV(_,_,k_block,_0{}), tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDPTtDPT); tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); ++pipeline_mma_compute_dp_producer_state; pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); // dV = P*dO CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { cute::gemm(tiled_mma_pdo, tDVrP(_,_,k_block), tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDVtDV); tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; } pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); ++pipeline_compute_mma_p_consumer_state; 
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); ++pipeline_load_mma_do_consumer_state; iter_count -= 1; } // signal to the epilogue that dV is ready pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); ++pipeline_mma_compute_dkdv_producer_state; pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); // dK = dS*Q CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { cute::gemm(tiled_mma_dsq, tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), tDKtDK); tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; } // signal to epilgue that dK is ready pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); ++pipeline_mma_compute_dkdv_producer_state; // we've already acquired mma_reduce_dq in the loop // dQ = dS*K tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { cute::gemm(tiled_mma_dsk, tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDQrKT(_,_,k_block,_0{}), tDQtDQ); tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); ++pipeline_mma_reduce_dq_producer_state; pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); ++pipeline_load_mma_q_release_state; pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); ++pipeline_compute_mma_ds_consumer_state; } template CUTLASS_DEVICE void store( TensorG gmem, TensorR const& regs, TensorC const& coord, TensorShape const& tensor_shape) { Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, 
tensor_shape); }); auto copy_op = make_cotiled_copy( Copy_Atom, Element>{}, make_layout(make_shape(_1{}, Int{})), regs.layout() ); auto thr_copy = copy_op.get_slice(_0{}); Tensor quantized_regs = quantize(regs); Tensor tCr = thr_copy.partition_S(quantized_regs); Tensor tCg = thr_copy.partition_D(gmem); Tensor tPc = thr_copy.partition_D(preds); copy_if(copy_op, tPc, tCr, tCg); } template CUTLASS_DEVICE void epilogue_clear( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args) { auto [Q, K, D, D_VO, HB] = problem_shape; auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDK = domain_offset( make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapeDSQ{})) ); auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDV = domain_offset( make_coord(blk_coord_k * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { if (elem_less(cDK(i), select<1,2>(problem_shape))) { gDK(i) = Element(0); } } for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { if (elem_less(cDV(i), select<1,3>(problem_shape))) { gDV(i) = Element(0); } } } template CUTLASS_DEVICE void epilogue( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ 
const& problem_shape, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); tDKtDK.data() = TmemAllocation::kDK; auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDK = domain_offset( make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapeDSQ{})) ); constexpr int kNumWarpgroups = kNumComputeWarps / 4; int dp_idx = threadIdx.x % 128; int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; auto split_wg = [&](auto const& t) { if constexpr (decltype(rank(t))::value == 3) { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); return p(_, _, make_coord(wg_idx, _)); } else { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); return p(_, _, _, make_coord(wg_idx, _)); } }; auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, 
select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); tDVtDV.data() = TmemAllocation::kDV; auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDV = domain_offset( make_coord(blk_coord_k * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); // load tDVtDV cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); // store tDVgDV store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); ++pipeline_mma_compute_dkdv_consumer_state; pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); // load tDKtDK cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rDK); i++) { tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); } // store tDKgDK store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); ++pipeline_mma_compute_dkdv_consumer_state; } template CUTLASS_DEVICE void compute( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, EpilogueArguments 
const& epilogue_args, TensorStorage& shared_tensors, PipelineLoadComputeLSE& pipeline_load_compute_lse, typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, PipelineMmaComputeS& pipeline_mma_compute_s, typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, PipelineMmaComputeDP& pipeline_mma_compute_dp, typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, PipelineComputeMmaP& pipeline_compute_mma_p, typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, PipelineComputeMmaDS& pipeline_compute_mma_ds, typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; // in tmem, S & P overlap // and dP and dQ overlap // there are two compute wg's that cooperatively compute softmax // they are striped by this tmem atom, i.e. 
wg0 has 16 elems, then wg1 etc auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; auto store_op = []() { if constexpr (sizeof(Element) == 1) { return SM100_TMEM_STORE_32dp32b4x{}; } else { return SM100_TMEM_STORE_32dp32b8x{}; } }(); Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); tSTtST.data() = TmemAllocation::kS; Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); tDPTtDPT.data() = TmemAllocation::kDP; Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); constexpr int kNumWarpgroups = kNumComputeWarps / 4; int dp_idx = threadIdx.x % 128; int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; auto tiled_t2r = make_tmem_copy(load_op, tSTtST); auto thread_t2r = tiled_t2r.get_slice(dp_idx); auto split_wg = [&](auto const& t) { if constexpr (decltype(size<1>(t))::value > 1) { if constexpr (decltype(rank(t))::value == 3) { auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); return p(_, make_coord(wg_idx, _), _); } else { auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); return p(_, make_coord(wg_idx, _), _, _); } } else { if constexpr (decltype(rank(t))::value == 3) { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); return p(_, _, make_coord(wg_idx, _)); } else { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); return p(_, _, _, make_coord(wg_idx, _)); } } }; Tensor tTR_cST_p = thread_t2r.partition_D(cST); Tensor tTR_cST = split_wg(tTR_cST_p); Tensor tTR_rST = make_tensor(shape(tTR_cST)); Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); Tensor tTR_cDPT = 
split_wg(tTR_cDPT_p); Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); tDVrP.data() = TmemAllocation::kP; auto tiled_r2t = make_tmem_copy(store_op, tDVrP); auto thread_r2t = tiled_r2t.get_slice(dp_idx); auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); auto tRT_cST_p = thread_r2t.partition_S(tDVcST); auto tRT_cST = split_wg(tRT_cST_p); bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); int last_iter = iter_count - 1 + iter_index; CUTLASS_PRAGMA_NO_UNROLL while (iter_count > 0) { // wait for S and P pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); // wait for LSE pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); auto dispatch_bool = [](bool b, auto fn) { if (b) { fn(cute::true_type{}); } else { fn(cute::false_type{}); } }; bool leading_causal_masking = false; if constexpr (std::is_base_of_v, Mask>) { leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); } else if constexpr (std::is_base_of_v, Mask>) { int offset = get<1>(problem_shape) - get<0>(problem_shape); int kv_left = get<1>(blk_coord) * TileShapeK{}; int kv_right = kv_left + TileShapeK{} - 1; int q_left = iter_index * TileShapeQ{} + offset; int q_right = q_left + TileShapeQ{} - 1; leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); } bool trailing_residual_masking = false; if constexpr 
(std::is_base_of_v) { trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); } dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { // compute P = softmax(S, LSE) cute::copy(tiled_t2r, tTR_tST, tTR_rST); if constexpr (decltype(is_masked_tile)::value) { Mask{}.apply_mask(tTR_rST, [&](int i) { auto c_transpose = tTR_cST(i); return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); }, problem_shape); } ElementAcc log2_e = static_cast(M_LOG2E); float2 softmax_scale_log2_e; softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rST); i += 2) { float2 acc; float2 lse; float2 out; acc.x = tTR_rST(i); acc.y = tTR_rST(i + 1); lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); cute::fma(out, softmax_scale_log2_e, acc, lse); tTR_rST(i) = ::exp2f(out.x); tTR_rST(i+1) = ::exp2f(out.y); } auto tRT_rST = quantize(tTR_rST); auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::NamedBarrier( kNumComputeWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransformBarrier ).arrive_and_wait(); cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); }); // notify for P cutlass::arch::fence_view_async_tmem_store(); pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); ++pipeline_compute_mma_p_producer_state; // release S pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); ++pipeline_mma_compute_s_consumer_state; // release LSE pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); ++pipeline_load_compute_lse_consumer_state; // wait for OdO 
pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); // wait for dP pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); // wait for dS // in principle, we could defer waiting for dS, and move in the freeing of dP // however, that would force us to keep dS in registers longer pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); // compute dS = dsoftmax(P, dP, sum_OdO) cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rDPT); i += 2) { float2 st; st.x = tTR_rST(i); st.y = tTR_rST(i+1); float2 dpt; dpt.x = tTR_rDPT(i); dpt.y = tTR_rDPT(i+1); float2 odo; odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); float2 dif; // sum odo is negated during preprocess cute::add(dif, dpt, odo); float2 out; cute::mul(out, dif, st); tTR_rDPT(i) = out.x; tTR_rDPT(i+1) = out.y; } auto tTR_rDST = quantize(tTR_rDPT); // release dP cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); ++pipeline_mma_compute_dp_consumer_state; Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) (_, _, _, pipeline_compute_mma_ds_producer_state.index()); auto thread_layout = make_ordered_layout( make_shape(_128{}, _128{}), make_stride(_1{}, _0{}) ); auto sDS_pi = as_position_independent_swizzle_tensor(sDS); auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); auto sDS_pi_slice = split_wg(sDS_pi_slice_p); copy_aligned(tTR_rDST, sDS_pi_slice); // notify for dS cutlass::arch::fence_view_async_shared(); pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); ++pipeline_compute_mma_ds_producer_state; // release OdO 
pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); ++pipeline_load_compute_sum_odo_consumer_state; iter_count -= 1; iter_index += 1; } epilogue( blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state ); } /* reduce(): run by the WarpRole::Reduce warps. For each of iter_count dQ tiles (starting at iter_index) it waits on pipeline_mma_reduce_dq, copies the dQ accumulator out of TMEM (column offset TmemAllocation::kDQ) into registers, releases that TMEM stage, then streams the registers slice by slice (one slice per index of mode 2 of tTR_cDQ) through the pipeline_reduce_tma_store smem stages and launches mainloop_params.tma_red_dq into gDQ. NOTE(review): tma_red_dq is presumably a TMA reduce-add so partial dQ from different K-tile CTAs accumulate in global memory; confirm where it is built. NOTE(review): thread_idx below is taken modulo kNumComputeWarps * NumThreadsPerWarp although only Reduce warps run this; that equals the reduce-local index only if the Reduce warps occupy the lowest thread indices; confirm against this kernel's warp assignment. NOTE(review): template argument lists (e.g. Step{} and make_tensor(...)) appear stripped in this extract. */ template CUTLASS_DEVICE void reduce( BlkCoord const& blk_coord, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, TensorStorage& shared_tensors, PipelineMmaReduceDQ& pipeline_mma_reduce_dq, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, PipelineReduceTmaStore& pipeline_reduce_tma_store, typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { using X = Underscore; auto [Q, K, D, D_VO, HB] = problem_shape; auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; // must match TileShapeDQ auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); tDQtDQ.data() = TmemAllocation::kDQ; Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) (_, _, _, _0{}, blk_coord_batch); Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); auto thread_t2r = tiled_t2r.get_slice(thread_idx); Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); Tensor tDQsDQ = 
block_tma.partition_S(sDQ); Tensor tDQcDQ = block_tma.partition_S(cDQ); Tensor tDQgDQ = block_tma.partition_D(gDQ); /* One thread per kNumReduceWarps * NumThreadsPerWarp threads is elected to acquire and commit TMA-store stages and to launch the TMA copy; all reduce threads fill smem between the two TransposeBarrier syncs. */ int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; /* One iteration per dQ tile handed over by the Mma warp through pipeline_mma_reduce_dq. */ while (iter_count > 0) { pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); // load dQ from tmem to rmem cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); ++pipeline_mma_reduce_dq_consumer_state; // we don't have enough smem to dump it all to smem, so we do it in stages CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<2>(tTR_cDQ); i++) { if (lane_predicate) { pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); } // wait in all threads for the acquire to complete cutlass::arch::NamedBarrier( kNumReduceWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransposeBarrier ).arrive_and_wait(); cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); // wait for the stores to all be visible to the TMA cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier( kNumReduceWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransposeBarrier ).arrive_and_wait(); if (lane_predicate) { // launch tma store copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); } ++pipeline_reduce_tma_store_producer_state; } iter_count -= 1; iter_index += 1; } } /* operator(): kernel entry. It assigns each warp a WarpRole, builds the inter-role pipelines in shared memory, derives this CTA's K tile and Q-tile range from blockIdx, then dispatches to load / mma / compute / reduce. It only does work when KERUTILS_ENABLE_SM100A is defined. */ CUTLASS_DEVICE void operator()(Params const& params, char* smem) { #if defined(KERUTILS_ENABLE_SM100A) int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); if (role == WarpRole::Load && lane_predicate) { 
/* The elected lane of the Load warp prefetches the q, k, v and do TMA descriptors. */ prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); } SharedStorage& shared_storage = *reinterpret_cast(smem); /* Every pipeline below receives its own initializing_warp value (post-incremented once per pipeline); presumably this spreads barrier initialisation over different warps. */ int initializing_warp = 0; typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; if (role == WarpRole::Load) { pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; } if (role == WarpRole::Mma) { pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; } pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); // Also loads K in the first iteration pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; pipeline_load_mma_q_params.initializing_warp = initializing_warp++; PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; if (role == WarpRole::Load) { pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; } if (role == WarpRole::Mma) { pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; } pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); // Also loads V in the first iteration pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; pipeline_load_mma_do_params.initializing_warp = initializing_warp++; PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; if (role == WarpRole::Load) { 
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; } /* LSE and sum(OdO) pipelines: one producer warp (Load), all compute threads consume. */ pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; PipelineLoadComputeLSE pipeline_load_compute_lse( shared_storage.pipelines.load_compute_lse, pipeline_load_compute_lse_params, /*barrier init*/ cute::true_type{}); typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; if (role == WarpRole::Load) { pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; } pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( shared_storage.pipelines.load_compute_sum_odo, pipeline_load_compute_sum_odo_params, /*barrier init*/ cute::true_type{}); typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; if (role == WarpRole::Mma) { pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; } pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; PipelineMmaComputeS pipeline_mma_compute_s( shared_storage.pipelines.mma_compute_s, pipeline_mma_compute_s_params, ClusterShape{}, 
/*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; if (role == WarpRole::Mma) { pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; } pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; PipelineMmaComputeDP pipeline_mma_compute_dp( shared_storage.pipelines.mma_compute_dp, pipeline_mma_compute_dp_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; if (role == WarpRole::Mma) { pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; } if (role == WarpRole::Reduce) { pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; } pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; PipelineMmaReduceDQ pipeline_mma_reduce_dq( shared_storage.pipelines.mma_reduce_dq, pipeline_mma_reduce_dq_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); /* P and dS flow the other way (Compute produces, Mma consumes): all compute threads arrive as producers, a single arrival marks consumption. */ typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; if (role == WarpRole::Mma) { pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; } if (role == WarpRole::Compute) { pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; } pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_compute_mma_p_params.consumer_arv_count = 1; pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; PipelineComputeMmaP pipeline_compute_mma_p( 
shared_storage.pipelines.compute_mma_p, pipeline_compute_mma_p_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; if (role == WarpRole::Mma) { pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; } if (role == WarpRole::Compute) { pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; } pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_compute_mma_ds_params.consumer_arv_count = 1; pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; PipelineComputeMmaDS pipeline_compute_mma_ds( shared_storage.pipelines.compute_mma_ds, pipeline_compute_mma_ds_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; if (role == WarpRole::Mma) { pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; } pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( shared_storage.pipelines.mma_compute_dkdv, pipeline_mma_compute_dkdv_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); /* Unlike the pipelines above, the reduce TMA-store pipeline is default-constructed here, without a Params object or a SharedStorage slot. */ PipelineReduceTmaStore pipeline_reduce_tma_store; TmemAllocator tmem_allocator; /* Barrier-init handshake: arrive here, set up masks and pipeline states, then pipeline_init_wait before any pipeline is used. */ pipeline_init_arrive_relaxed(size(ClusterShape{})); pipeline_load_mma_q.init_masks(ClusterShape{}); pipeline_load_mma_do.init_masks(ClusterShape{}); pipeline_mma_compute_s.init_masks(ClusterShape{}); pipeline_mma_compute_dp.init_masks(ClusterShape{}); pipeline_mma_reduce_dq.init_masks(ClusterShape{}); 
pipeline_compute_mma_p.init_masks(ClusterShape{}); pipeline_compute_mma_ds.init_masks(ClusterShape{}); pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; auto pipeline_load_mma_q_producer_state = make_producer_start_state(); auto pipeline_load_mma_do_producer_state = make_producer_start_state(); auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); 
pipeline_init_wait(size(ClusterShape{})); /* This CTA owns K tile blockIdx.x of (head, batch) = (blockIdx.y, blockIdx.z); the Q, D and D_VO coordinates are fixed at 0. */ auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); auto [problem_shape, blk_offset] = apply_variable_length_offset( params.problem_shape, blk_coord ); /* iter_count starts as the number of TileShapeQ tiles covering Q. For the two Mask base types tested below (template arguments missing from this extract; presumably causal variants), iter_start skips the leading Q tiles this K tile cannot attend to; the second variant shifts by offset = K - Q. */ int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); int iter_start = 0; if constexpr (std::is_base_of_v, Mask>) { iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; } else if constexpr (std::is_base_of_v, Mask>) { int offset = get<1>(problem_shape) - get<0>(problem_shape); iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); } /* The K tile begins at or past the (variable-length adjusted) K extent: no work, every thread of the CTA exits. */ if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { return; } /* If no Q tile remains, no role runs; all threads call epilogue_clear (presumably zero-filling this tile's dK/dV; confirm in epilogue_clear) and exit. */ iter_count -= iter_start; if (iter_count <= 0) { epilogue_clear( blk_coord, blk_offset, problem_shape, params.mainloop, params.epilogue ); return; } if (role == WarpRole::Load) { warpgroup_reg_set(); load( blk_coord, blk_offset, problem_shape, iter_start, iter_count, params.mainloop, params.mainloop_params, shared_storage.tensors, pipeline_load_mma_q, pipeline_load_mma_q_producer_state, pipeline_load_mma_do, pipeline_load_mma_do_producer_state, pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state ); } else if (role == WarpRole::Mma) { warpgroup_reg_set(); /* The Mma warp allocates the full TMEM capacity and publishes the base via shared_storage.tmem_base_ptr; a compute warp frees it after the compute role completes. */ tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); __syncwarp(); mma( blk_coord, problem_shape, iter_start, iter_count, params.mainloop, shared_storage.tensors, pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, 
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state ); } else if (role == WarpRole::Compute) { warpgroup_reg_set(); compute( blk_coord, blk_offset, problem_shape, iter_start, iter_count, params.mainloop, params.epilogue, shared_storage.tensors, pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state ); /* All compute threads sync on EpilogueBarrier before TMEM is released; then the single warp with warp_idx % kNumComputeWarps == 0 frees the whole allocation made by the Mma warp. */ cutlass::arch::NamedBarrier( kNumComputeWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier ).arrive_and_wait(); if (warp_idx % kNumComputeWarps == 0) { uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); } } else if (role == WarpRole::Reduce) { warpgroup_reg_set(); reduce( blk_coord, problem_shape, iter_start, iter_count, params.mainloop, params.mainloop_params, shared_storage.tensors, pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state ); /* Presumably waits for outstanding TMA stores before the reduce warps exit; confirm in PipelineTmaStore::producer_tail. */ pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); } else { warpgroup_reg_set(); /* no-op */ } #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); } #endif } /* Launch block shape: one 1-D block of MaxThreadsPerBlock threads. */ static dim3 get_block_shape() { dim3 block(MaxThreadsPerBlock, 1, 1); return block; } /* Launch grid: x = number of TileShapeK-wide K tiles, y = H, z = B. This matches operator(), which reads the K tile from blockIdx.x and (head, batch) from (blockIdx.y, blockIdx.z). */ static dim3 get_grid_shape(Params const& params) { auto [Q, K, D, D_VO, HB] = params.problem_shape; auto [H, B] = HB; dim3 grid(ceil_div(K, TileShapeK{}), H, B); return grid; } }; } // namespace cutlass::fmha::kernel ================================================ FILE: 
csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" #include "cute/tensor.hpp" #include "cute/arch/simd_sm100.hpp" #include "cutlass/arch/arch.h" #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" #include // for KERUTILS_ENABLE_SM100A #include "../collective/fmha_common.hpp" #include namespace cutlass::fmha::kernel { using namespace cutlass::fmha::collective; using namespace cute; template< class ProblemShape, class Element, class ElementAcc, class TileShape, class Mask > struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { using TileShapeQ = decltype(get<0>(TileShape{})); using TileShapeK = decltype(get<1>(TileShape{})); using TileShapeDQK = decltype(get<2>(TileShape{})); using TileShapeDVO = decltype(get<3>(TileShape{})); using TmemAllocator = cute::TMEM::Allocator1Sm; struct TmemAllocation { static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp static constexpr uint32_t kS = kDQ + 65536 * 16; static constexpr uint32_t kP = kS; static constexpr uint32_t kTotal = kDQ + TileShapeDQK{}; }; static_assert( static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem" ); enum class WarpRole { Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 }; static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; static constexpr int kNumComputeWarps = 8; static constexpr int kNumReduceWarps = 4; static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp; static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp"); CUTLASS_DEVICE WarpRole warp_idx_to_role(int 
warp_idx) { return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); } struct RegisterAllocation { static constexpr int kWarpgroup0 = 160-8; static constexpr int kWarpgroup1 = 128; static constexpr int kWarpgroup2 = 96; static constexpr int kReduce = kWarpgroup0; static constexpr int kCompute = kWarpgroup1; static constexpr int kMma = kWarpgroup2; static constexpr int kEmpty = kWarpgroup2; static constexpr int kLoad = kWarpgroup2; static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); }; using ArchTag = cutlass::arch::Sm100; using ClusterShape = Shape<_1, _1, _1>; using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; static constexpr int MinBlocksPerMultiprocessor = 1; static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; static constexpr int Alignment = 128 / sizeof_bits_v; static constexpr int kStages = 2; using TensorStrideContiguousK = Stride>; using TensorStrideContiguousMN = Stride<_1, int, Stride>; // compute S using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, TensorStrideContiguousK, Alignment, Element, TensorStrideContiguousK, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeQK = typename CollectiveMmaQK::TileShape; using TiledMmaQK = typename CollectiveMmaQK::TiledMma; // compute dP using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, Element, TensorStrideContiguousK, Alignment, Element, TensorStrideContiguousK, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeDOV = typename CollectiveMmaDOV::TileShape; using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma; // compute dV using CollectiveMmaPDO = typename 
cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // needs to match ordering of S calculation Element, TensorStrideContiguousK, Alignment, Element, TensorStrideContiguousMN, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapePDO = typename CollectiveMmaPDO::TileShape; using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma; // compute dK using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // somewhat arbitrary since we dump to smem, need to agree with the next one Element, TensorStrideContiguousK , Alignment, Element, TensorStrideContiguousMN, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; // compute dQ using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // somewhat arbitrary since we dump to smem, need to agree with the previous one Element, TensorStrideContiguousMN, Alignment, Element, TensorStrideContiguousMN, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TileShapeDSK = typename CollectiveMmaDSK::TileShape; using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; // pipelines are named Pipeline static constexpr int kStagesComputeSmem = 1; using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; using PipelineLoadComputeLSE = PipelineAsync<1>; using PipelineLoadComputeSumOdO = PipelineAsync<1>; using PipelineMmaComputeS = PipelineUmmaAsync<1>; using PipelineMmaComputeDP = PipelineUmmaAsync<1>; using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; using PipelineComputeMmaP = 
PipelineUmmaConsumerAsync<1>; using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; static constexpr int kStagesReduceTmaStore = 2; using PipelineReduceTmaStore = PipelineTmaStore; struct PipelineStorage { alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; }; template static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { return composition(layout, make_tuple(_, _, _, make_layout(stages))); } using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{})); using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{})); using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{})); using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{})); using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); using SmemLayoutLSE = Layout>; using SmemLayoutSumOdO = Layout>; using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); 
// The next line declares:
//   - SmemLayoutP / SmemLayoutPT, and the dQ staging layout SmemLayoutDQ built by
//     tiling a K-major smem atom (TileShapeDQ = 32 columns).
//   - TensorStorage: shared-memory buffers. Each operand and its transposed view
//     share storage through a union (K/K^T, Q/Q^T, dO/dO^T, dS/dS^T, P/P^T);
//     V, dQ, LSE and sum_OdO have their own buffers.
//   - kTransactionsBytesLoad{Q,DO,K,V}: bytes expected per TMA load of one stage
//     (the first three modes of each smem layout).
//   - SharedStorage and SharedStorageSize. The size uses offsetof instead of sizeof
//     (see the inline comment) and is statically checked against
//     cutlass::arch::sm100_smem_capacity_bytes.
//   - The start of the TensorStride alias (finished on the next line).
using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{})); using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{})); using TileShapeDQ = _32; using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ >()); using SmemShapeDQ = Shape>; using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); struct TensorStorage { union { alignas(2048) cute::array> smem_k; alignas(2048) cute::array> smem_k_t; }; alignas(2048) cute::array> smem_v; union { alignas(2048) cute::array> smem_q; alignas(2048) cute::array> smem_q_t; }; union { alignas(2048) cute::array> smem_do; alignas(2048) cute::array> smem_do_t; }; union { alignas(2048) cute::array> smem_ds; alignas(2048) cute::array> smem_ds_t; }; union{ alignas(2048) cute::array> smem_p; alignas(2048) cute::array> smem_p_t; }; alignas(1024) cute::array> smem_dq; alignas(16) cute::array> smem_lse; alignas(16) cute::array> smem_sum_odo; }; static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); struct SharedStorage { TensorStorage tensors; PipelineStorage pipelines; uint32_t tmem_base_ptr; }; // this is tight enough that it won't work with sizeof due to padding for alignment static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); using TensorStride =
// The next line declares:
//   - TensorStride (= TensorStrideContiguousK, indexed "S D (H B)") and
//     RowTensorStride (indexed "S (H B)"), the latter used for LSE / sum_OdO.
//   - MainloopArguments: pointers and strides for Q, K, V, dO, LSE, sum_OdO and the
//     ElementAcc dQ accumulator; softmax_scale defaults to 1 / sqrt(TileShapeDQK).
//   - TMA_K / TMA_V / TMA_Q / TMA_DO, taken from the collective MMA params, and
//     TMA_DQ, an SM90_TMA_REDUCE_ADD copy (built over ptr_dq_acc in
//     to_underlying_arguments).
//   - MainloopParams, EpilogueArguments (dK / dV outputs), Arguments, Params.
//   - can_implement(): rejects non-positive extents and head dims that are not a
//     multiple of Alignment.
//     NOTE(review): it does not check D <= TileShapeDQK or D_VO <= TileShapeDVO,
//     yet load() only fetches D-tile index 0 and epilogue() sizes dK/dV with
//     TileShapeDQK / TileShapeDVO. Confirm that callers guarantee each head dim
//     fits in one tile.
//   - initialize_workspace(): no workspace is used; always returns kSuccess.
//   - The start of to_underlying_arguments() (finished on the next line).
TensorStrideContiguousK; // S D (H B) using RowTensorStride = Stride<_1, Stride>; // S (H B) struct MainloopArguments { const Element* ptr_q; TensorStride stride_q; const Element* ptr_k; TensorStride stride_k; const Element* ptr_v; TensorStride stride_v; const Element* ptr_do; TensorStride stride_do; const ElementAcc* ptr_lse; RowTensorStride stride_lse; const ElementAcc* ptr_sum_odo; RowTensorStride stride_sum_odo; ElementAcc* ptr_dq_acc; TensorStride stride_dq_acc; ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); }; using TMA_K = typename CollectiveMmaQK::Params::TMA_B; using TMA_V = typename CollectiveMmaDOV::Params::TMA_B; using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A; using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), SmemLayoutDQ{}(_, _, _0{}) )); struct MainloopParams { TMA_K tma_load_k; TMA_V tma_load_v; TMA_Q tma_load_q; TMA_DO tma_load_do; TMA_DQ tma_red_dq; }; struct EpilogueArguments { Element* ptr_dk; TensorStride stride_dk; Element* ptr_dv; TensorStride stride_dv; }; struct Arguments { ProblemShape problem_shape; MainloopArguments mainloop; EpilogueArguments epilogue; KernelHardwareInfo hw_info; }; struct Params { ProblemShape problem_shape; MainloopArguments mainloop; MainloopParams mainloop_params; EpilogueArguments epilogue; KernelHardwareInfo hw_info; }; static bool can_implement(Arguments const& args) { auto [Q, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) { return false; } if (D % Alignment != 0 || D_VO % Alignment != 0) { return false; } return true; } static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { return Status::kSuccess; } static Params to_underlying_arguments(Arguments const& args, void*) { auto [Q_, K_, D, D_VO, HB] = args.problem_shape; int Q = Q_; int K = K_; if
// The next line finishes to_underlying_arguments():
//   - For variable-length problems, Q/K are replaced by total_length.
//   - It builds the QK and dO-V collective params, which supply the K/Q/V/dO TMA
//     loads, and creates TMA_DQ.
//   NOTE(review): TMA_DQ is built over make_shape(Q_, D, HB), i.e. the original
//   (possibly variable-length) Q_, not the total_length-adjusted Q used for the
//   load descriptors. Confirm this is intended for the varlen case.
// The line then defines quantize(), which converts a register tensor to Element in
// groups of AlignmentS = 4 using NumericArrayConverter, and starts load().
//
// load() is the producer side of the Q/dO TMA pipelines and the LSE/sum_OdO
// cp.async pipelines.
//   - First iteration: K and Q share the Q-pipeline barrier (K's bytes are added
//     via producer_expect_transaction) and LSE is copied. Then V and dO share the
//     dO-pipeline barrier (V's bytes are added the same way) and sum_OdO is copied.
//   - Remaining iterations: only Q, LSE, dO and sum_OdO for the next iter_index.
//     K and V are loaded once, into stage 0.
//   - LSE / sum_OdO rows with index >= Q are zero-filled by cp_async_zfill.
constexpr (is_variable_length_v) { Q = Q_.total_length; } if constexpr (is_variable_length_v) { K = K_.total_length; } auto params_kq = CollectiveMmaQK::to_underlying_arguments( make_shape(Q, K, D, HB), typename CollectiveMmaQK::Arguments { args.mainloop.ptr_q, args.mainloop.stride_q, args.mainloop.ptr_k, args.mainloop.stride_k, }, /*workspace=*/nullptr); auto params_vdo = CollectiveMmaDOV::to_underlying_arguments( make_shape(Q, K, D_VO, HB), typename CollectiveMmaDOV::Arguments { args.mainloop.ptr_do, args.mainloop.stride_do, args.mainloop.ptr_v, args.mainloop.stride_v, }, /*workspace=*/nullptr); TMA_DQ tma_red_dq = make_tma_copy( SM90_TMA_REDUCE_ADD{}, make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), SmemLayoutDQ{}(_, _, _0{}) ); return Params{ args.problem_shape, args.mainloop, MainloopParams{ params_kq.tma_load_b, params_vdo.tma_load_b, params_kq.tma_load_a, params_vdo.tma_load_a, tma_red_dq }, args.epilogue, args.hw_info }; } template static CUTLASS_DEVICE auto quantize(T const& input) { constexpr int AlignmentS = 4; auto output = make_tensor(shape(input)); auto input_vec = recast>(input); auto output_vec = recast>(output); cutlass::NumericArrayConverter epilogue_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(input_vec); i++) { output_vec(i) = epilogue_op(input_vec(i)); } return output; } template CUTLASS_DEVICE void load( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, TensorStorage& shared_tensors, PipelineLoadMmaQ& pipeline_load_mma_q, typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, PipelineLoadMmaDO& pipeline_load_mma_do, typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, PipelineLoadComputeLSE& pipeline_load_compute_lse, typename PipelineLoadComputeLSE::PipelineState&
// (load() continued) The next line sets up the gmem views.
//   - The K/V/Q/dO TMA tensors are offset by blk_offset and tiled.
//   - K/V are indexed by the K-tile coordinate, Q/dO by iter_index.
//   - The matching smem tensors are created.
pipeline_load_compute_lse_producer_state, PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; using X = Underscore; uint16_t mcast_mask = 0; auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step{}); auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step{}); auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{}); ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{}); ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{}); auto tSTgK = cta_mma_kq.partition_B(gK); auto tSTgQ = cta_mma_kq.partition_A(gQ); auto tDPTgV = cta_mma_vdo.partition_B(gV); auto tDPTgDO = cta_mma_vdo.partition_A(gDO); auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); auto [tKgK_mkl, tKsK] = tma_partition( mainloop_params.tma_load_k, _0{}, make_layout(_1{}), group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); auto [tQgQ_mkl, tQsQ] = tma_partition( mainloop_params.tma_load_q, _0{}, make_layout(_1{}),
// (load() continued) The next line finishes the TMA partitioning and starts the
// first iteration.
//   - K is loaded into stage 0 (D-tile index 0) and Q into the current Q stage.
//   - LSE is copied: each of the NumThreadsPerWarp lanes handles kLoadPerThread
//     consecutive values.
group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); auto [tVgV_mkl, tVsV] = tma_partition( mainloop_params.tma_load_v, _0{}, make_layout(_1{}), group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); auto [tDOgDO_mkl, tDOsDO] = tma_partition( mainloop_params.tma_load_do, _0{}, make_layout(_1{}), group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); // set up lse and sum_odo auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); // load K if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), tKsK(_, _0{}) ); } // load Q if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), tQsQ(_, pipeline_load_mma_q_producer_state.index()) ); } ++pipeline_load_mma_q_producer_state; pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); // load LSE // 32 threads loading kLoadPerThread * 32 values of 32b each int thread_idx = threadIdx.x % NumThreadsPerWarp; int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); for (int i = 0; i < kLoadPerThread; i++) { cutlass::arch::cp_async_zfill<4>( shared_tensors.smem_lse.begin() + smem_idx + i, &mLSE(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
// (load() continued) The next line finishes the first iteration and starts the loop.
//   - V is loaded into stage 0 and dO into the current dO stage, sharing the
//     dO-pipeline barrier.
//   - sum_OdO is copied.
//   - The steady-state loop begins with the next Q tile.
++pipeline_load_compute_lse_producer_state; pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); // load V if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), tVsV(_, _0{}) ); } // load dO if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), tDOsDO(_, pipeline_load_mma_do_producer_state.index()) ); } ++pipeline_load_mma_do_producer_state; pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); // load sum_OdO smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); for (int i = 0; i < kLoadPerThread; i++) { cutlass::arch::cp_async_zfill<4>( shared_tensors.smem_sum_odo.begin() + smem_idx + i, &mSumOdO(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; iter_count -= 1; iter_index += 1; while (iter_count > 0) { pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); // load Q if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), tQsQ(_, pipeline_load_mma_q_producer_state.index()) ); }
// (load() continued) The next line runs the rest of each loop iteration (LSE, dO,
// sum_OdO) and ends load(). It then begins mma(), whose signature continues on the
// following line.
++pipeline_load_mma_q_producer_state; pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); // load LSE smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; for (int i = 0; i < kLoadPerThread; i++) { cutlass::arch::cp_async_zfill<4>( shared_tensors.smem_lse.begin() + smem_idx + i, &mLSE(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); // load dO if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), tDOsDO(_, pipeline_load_mma_do_producer_state.index()) ); } ++pipeline_load_mma_do_producer_state; pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); // load sum_OdO smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; for (int i = 0; i < kLoadPerThread; i++) { cutlass::arch::cp_async_zfill<4>( shared_tensors.smem_sum_odo.begin() + smem_idx + i, &mSumOdO(gmem_idx + i, blk_coord_batch), gmem_idx + i < Q ); } pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; iter_count -= 1; iter_index += 1; } } template CUTLASS_DEVICE void mma( BlkCoord const& blk_coord, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, TensorStorage& shared_tensors,
// mma() (continued) issues every UMMA GEMM for one K block.
//   - Prologue: S = Q*K, dP = dO*V, then dV = P*dO. The dV accumulator starts
//     from zero because accumulate_ is initialised to ScaleOut::Zero.
//   - Loop, per remaining Q tile:
//       * S = Q*K.
//       * Once dS is ready: dQ = dS*K (committed to the dQ reduce pipeline) and
//         dK += dS*Q. The dK GEMM reads the Q stage tracked by
//         pipeline_load_mma_q_release_state, which is released afterwards.
//       * dP = dO*V, then dV += P*dO.
//     In tensor memory S/P alias and dP/dQ alias, so the dP and dQ pipeline slots
//     are acquired together.
//   - After the loop: signal dV ready, run the last dK GEMM and signal dK ready,
//     then run the last dQ GEMM.
// The next line finishes the signature and builds the smem tensors, including the
// transposed views.
PipelineLoadMmaQ& pipeline_load_mma_q, typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, PipelineLoadMmaDO& pipeline_load_mma_do, typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, PipelineMmaComputeS& pipeline_mma_compute_s, typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, PipelineMmaComputeDP& pipeline_mma_compute_dp, typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, PipelineMmaReduceDQ& pipeline_mma_reduce_dq, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, PipelineComputeMmaP& pipeline_compute_mma_p, typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, PipelineComputeMmaDS& pipeline_compute_mma_ds, typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{}); auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); Tensor tSTrK = TiledMmaQK::make_fragment_B(sK); Tensor
// (mma() continued) The next line builds the UMMA operand fragments and places the
// accumulators at fixed TmemAllocation offsets (kS, kDP, kDQ, kDK, kDV). It then
// runs the prologue S = Q*K.
tSTrQ = TiledMmaQK::make_fragment_A(sQ); Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV); Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO); Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP); Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); TiledMmaQK tiled_mma_qk; TiledMmaDOV tiled_mma_dov; TiledMmaDSK tiled_mma_dsk; TiledMmaDSQ tiled_mma_dsq; TiledMmaPDO tiled_mma_pdo; tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); tSTtST.data() = TmemAllocation::kS; Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{})); tDPTtDPT.data() = TmemAllocation::kDP; Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); tDQtDQ.data() = TmemAllocation::kDQ; Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); tDKtDK.data() = TmemAllocation::kDK; Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); tDVtDV.data() = TmemAllocation::kDV; auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); // S = Q*K tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { cute::gemm(tiled_mma_qk, tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), tSTrK(_,_,k_block,_0{}), tSTtST); tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; } ++pipeline_load_mma_q_consumer_state; pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); ++pipeline_mma_compute_s_producer_state;
// (mma() continued) The next line runs the rest of the prologue and starts the loop.
//   - dP = dO*V. The dQ slot is acquired here too, because it aliases dP.
//   - dV = P*dO.
//   - The main loop begins with S = Q*K.
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); // dP = dO*V tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { cute::gemm(tiled_mma_dov, tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDPTrV(_,_,k_block,_0{}), tDPTtDPT); tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); ++pipeline_mma_compute_dp_producer_state; pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); // dV = P*dO CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { cute::gemm(tiled_mma_pdo, tDVrP(_,_,k_block,_0{}), tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDVtDV); tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; } pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); ++pipeline_compute_mma_p_consumer_state; pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); ++pipeline_load_mma_do_consumer_state; iter_count -= 1; // in tmem, S & P overlap // and dP and dQ overlap // so we need to acquire dQ and dP at the same time while (iter_count > 0) { pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); // S = Q*K tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { cute::gemm(tiled_mma_qk, tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), tSTrK(_,_,k_block,_0{}), tSTtST); tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; } ++pipeline_load_mma_q_consumer_state; pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
// (mma() continued) The next line is the loop body.
//   - dQ = dS*K.
//   - dK += dS*Q.
//   - The consumed Q stage and dS are released.
//   - The dQ slot is re-acquired, then dP = dO*V.
++pipeline_mma_compute_s_producer_state; pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); // we need to acquire dP here, because tmem dQ == tmem dP pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); // dQ = dS*K tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { cute::gemm(tiled_mma_dsk, tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDQrKT(_,_,k_block,_0{}), tDQtDQ); tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); ++pipeline_mma_reduce_dq_producer_state; // dK = dS*Q CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { cute::gemm(tiled_mma_dsq, tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), tDKtDK); tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; } pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); ++pipeline_load_mma_q_release_state; pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); ++pipeline_compute_mma_ds_consumer_state; // we grab dq here, because in tmem dq == dp pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); // dP = dO*V tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { cute::gemm(tiled_mma_dov, tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDPTrV(_,_,k_block,_0{}), tDPTtDPT); tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); ++pipeline_mma_compute_dp_producer_state; pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); // dV = P*dO
// (mma() continued) The next line finishes the loop body with dV += P*dO, then runs
// the tail.
//   - Signal dV ready.
//   - Final dK += dS*Q, then signal dK ready.
//   - Final dQ = dS*K, committed to the dQ reduce pipeline.
CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { cute::gemm(tiled_mma_pdo, tDVrP(_,_,k_block,_0{}), tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), tDVtDV); tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; } pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); ++pipeline_compute_mma_p_consumer_state; pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); ++pipeline_load_mma_do_consumer_state; iter_count -= 1; } // signal to the epilogue that dV is ready pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); ++pipeline_mma_compute_dkdv_producer_state; pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); // dK = dS*Q CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { cute::gemm(tiled_mma_dsq, tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), tDKtDK); tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; } // signal to epilogue that dK is ready pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); ++pipeline_mma_compute_dkdv_producer_state; // we've already acquired mma_reduce_dq in the loop // dQ = dS*K tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { cute::gemm(tiled_mma_dsk, tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), tDQrKT(_,_,k_block,_0{}), tDQtDQ); tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; } pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); ++pipeline_mma_reduce_dq_producer_state; pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
// (mma() continued) The next line does the final Q-stage and dS releases, which end
// mma(). It then defines:
//   - store(): quantizes a register tile to Element and writes it to global memory
//     with copy_if, predicated on each coordinate being elem_less than tensor_shape.
//   - The start of epilogue_clear() (finished on the next line).
++pipeline_load_mma_q_release_state; pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); ++pipeline_compute_mma_ds_consumer_state; } template CUTLASS_DEVICE void store( TensorG gmem, TensorR const& regs, TensorC const& coord, TensorShape const& tensor_shape) { Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); auto copy_op = make_cotiled_copy( Copy_Atom, Element>{}, make_layout(make_shape(_1{}, Int{})), regs.layout() ); auto thr_copy = copy_op.get_slice(_0{}); Tensor quantized_regs = quantize(regs); Tensor tCr = thr_copy.partition_S(quantized_regs); Tensor tCg = thr_copy.partition_D(gmem); Tensor tPc = thr_copy.partition_D(preds); copy_if(copy_op, tPc, tCr, tCg); } template CUTLASS_DEVICE void epilogue_clear( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args) { auto [Q, K, D, D_VO, HB] = problem_shape; auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDK = domain_offset( make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapeDSQ{})) ); auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDV = domain_offset( make_coord(blk_coord_k * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); for (int i =
// The next line finishes epilogue_clear().
//   - Every thread of the block strides (by blockDim.x) over the dK and dV tiles.
//   - It writes Element(0) to entries inside the (K, D) and (K, D_VO) extents.
//   - Presumably this is used for K blocks that receive no Q iterations; the caller
//     is not visible here.
// The line then begins epilogue().
//   - epilogue() moves the dK/dV accumulators out of tensor memory
//     (SM100_TMEM_LOAD_32dp32b16x) into registers and stores them.
//   - The compute threads are split into warpgroups (split_wg) along the last mode.
threadIdx.x; i < size(gDK); i += blockDim.x) { if (elem_less(cDK(i), select<1,2>(problem_shape))) { gDK(i) = Element(0); } } for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { if (elem_less(cDV(i), select<1,3>(problem_shape))) { gDV(i) = Element(0); } } } template CUTLASS_DEVICE void epilogue( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); tDKtDK.data() = TmemAllocation::kDK; auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDK = domain_offset( make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapeDSQ{})) ); constexpr int kNumWarpgroups = kNumComputeWarps / 4; int dp_idx = threadIdx.x % 128; int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; auto split_wg = [&](auto const& t) { if constexpr (decltype(rank(t))::value == 3) { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); return p(_, _, make_coord(wg_idx, _)); } else { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); return p(_, _, _, make_coord(wg_idx, _)); } }; auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK);
// (epilogue() continued) The next line partitions the TMEM -> register copies per
// thread (dp_idx = threadIdx.x % 128) and per warpgroup (wg_idx), then drains the
// results.
//   - dV is drained and stored first, without scaling.
//   - dK is drained, multiplied by mainloop_args.softmax_scale, and stored.
//   - After each TMEM read, fence_view_async_tmem_load is issued before the
//     pipeline slot is released.
auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{}); tDVtDV.data() = TmemAllocation::kDV; auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDV = domain_offset( make_coord(blk_coord_k * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); // load tDVtDV cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); // store tDVgDV store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); ++pipeline_mma_compute_dkdv_consumer_state; pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); // load tDKtDK cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rDK); i++) { tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); } // store tDKgDK store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); cutlass::arch::fence_view_async_tmem_load();
// (epilogue() continued) The next line releases the dK slot, which ends epilogue().
// It then begins compute(): the signature plus the start of its body, which
// continues beyond this chunk.
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); ++pipeline_mma_compute_dkdv_consumer_state; } template CUTLASS_DEVICE void compute( BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args, TensorStorage& shared_tensors, PipelineLoadComputeLSE& pipeline_load_compute_lse, typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, PipelineMmaComputeS& pipeline_mma_compute_s, typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, PipelineMmaComputeDP& pipeline_mma_compute_dp, typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, PipelineComputeMmaP& pipeline_compute_mma_p, typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, PipelineComputeMmaDS& pipeline_compute_mma_ds, typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; // in tmem, S & P overlap // and dP and dQ overlap // there are two compute wg's that cooperatively compute softmax // they are striped by this tmem atom, i.e.
wg0 has 16 elems, then wg1 etc auto load_op = SM100_TMEM_LOAD_16dp32b32x{}; Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{}); tSTtST.data() = TmemAllocation::kS; Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{}); tDPTtDPT.data() = TmemAllocation::kDP; Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{})); Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{})); Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{})); constexpr int kNumWarpgroups = kNumComputeWarps / 4; int dp_idx = threadIdx.x % 128; int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; auto tiled_t2r = make_tmem_copy(load_op, tSTtST); auto thread_t2r = tiled_t2r.get_slice(dp_idx); auto split_wg = [&](auto const& t) { if constexpr (decltype(size<1>(t))::value > 1) { if constexpr (decltype(rank(t))::value == 3) { auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); return p(_, make_coord(wg_idx, _), _); } else { auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); return p(_, make_coord(wg_idx, _), _, _); } } else { if constexpr (decltype(rank(t))::value == 3) { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); return p(_, _, make_coord(wg_idx, _)); } else { auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); return p(_, _, _, make_coord(wg_idx, _)); } } }; Tensor tTR_cST_p = thread_t2r.partition_D(cST); Tensor tTR_cST = split_wg(tTR_cST_p); Tensor tTR_rST = make_tensor(shape(tTR_cST)); Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); Tensor tTR_cPT_p = thread_t2r.partition_D(cPT); Tensor tTR_cDPT = split_wg(tTR_cDPT_p); Tensor tTR_rDPT = 
make_tensor(shape(tTR_cDPT)); Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); int last_iter = iter_count - 1 + iter_index; CUTLASS_PRAGMA_NO_UNROLL while (iter_count > 0) { // wait for S and P pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); // wait for LSE pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); auto dispatch_bool = [](bool b, auto fn) { if (b) { fn(cute::true_type{}); } else { fn(cute::false_type{}); } }; bool leading_causal_masking = false; if constexpr (std::is_base_of_v, Mask>) { leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); } else if constexpr (std::is_base_of_v, Mask>) { int offset = get<1>(problem_shape) - get<0>(problem_shape); int kv_left = get<1>(blk_coord) * TileShapeK{}; int kv_right = kv_left + TileShapeK{} - 1; int q_left = iter_index * TileShapeQ{} + offset; int q_right = q_left + TileShapeQ{} - 1; leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); } bool trailing_residual_masking = false; if constexpr (std::is_base_of_v) { trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); } dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { // compute P = softmax(S, LSE) cute::copy(tiled_t2r, tTR_tST, tTR_rST); if constexpr (decltype(is_masked_tile)::value) { Mask{}.apply_mask(tTR_rST, [&](int i) { auto c_transpose = tTR_cST(i); return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); }, problem_shape); } ElementAcc 
log2_e = static_cast(M_LOG2E); float2 softmax_scale_log2_e; softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rST); i += 2) { float2 acc; float2 lse; float2 out; acc.x = tTR_rST(i); acc.y = tTR_rST(i + 1); lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); cute::fma(out, softmax_scale_log2_e, acc, lse); tTR_rST(i) = ::exp2f(out.x); tTR_rST(i+1) = ::exp2f(out.y); } auto tRT_rST = quantize(tTR_rST); Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) (_, _, _, pipeline_compute_mma_p_producer_state.index()); cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::NamedBarrier( kNumComputeWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransformBarrier ).arrive_and_wait(); auto sP_pi = as_position_independent_swizzle_tensor(sP); auto thread_layout = make_ordered_layout( make_shape(_64{}, _32{}, _2{}, _2{}), make_stride(_3{}, _0{}, _1{}, _2{}) ); auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p))); auto sP_pi_slice = split_wg(sP_pi_slice_p); copy_aligned(tRT_rST, sP_pi_slice); }); // notify for P cutlass::arch::fence_view_async_shared(); pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); ++pipeline_compute_mma_p_producer_state; // release S pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); ++pipeline_mma_compute_s_consumer_state; // release LSE pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); ++pipeline_load_compute_lse_consumer_state; // wait for OdO pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); // wait for dP 
pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); // wait for dS // in principle, we could defer waiting for dS, and move in the freeing of dP // however, that would force us to keep dS in registers longer pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); // compute dS = dsoftmax(P, dP, sum_OdO) cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rDPT); i += 2) { float2 st; st.x = tTR_rST(i); st.y = tTR_rST(i+1); float2 dpt; dpt.x = tTR_rDPT(i); dpt.y = tTR_rDPT(i+1); float2 odo; odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); float2 dif; // sum odo is negated during preprocess cute::add(dif, dpt, odo); float2 out; cute::mul(out, dif, st); tTR_rDPT(i) = out.x; tTR_rDPT(i+1) = out.y; } auto tTR_rDST = quantize(tTR_rDPT); // release dP cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); ++pipeline_mma_compute_dp_consumer_state; Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}) (_, _, _, pipeline_compute_mma_ds_producer_state.index()); auto thread_layout = make_ordered_layout( make_shape(_64{}, _32{}, _2{}, _2{}), make_stride(_3{}, _0{}, _1{}, _2{}) ); auto sDS_pi = as_position_independent_swizzle_tensor(sDS); auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p))); auto sDS_pi_slice = split_wg(sDS_pi_slice_p); copy_aligned(tTR_rDST, sDS_pi_slice); // notify for dS cutlass::arch::fence_view_async_shared(); pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); ++pipeline_compute_mma_ds_producer_state; // release OdO 
pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); ++pipeline_load_compute_sum_odo_consumer_state; iter_count -= 1; iter_index += 1; } epilogue( blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state ); } template CUTLASS_DEVICE void reduce( BlkCoord const& blk_coord, ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, TensorStorage& shared_tensors, PipelineMmaReduceDQ& pipeline_mma_reduce_dq, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, PipelineReduceTmaStore& pipeline_reduce_tma_store, typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { using X = Underscore; auto [Q, K, D, D_VO, HB] = problem_shape; auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; // must match TileShapeDQ auto load_op = SM100_TMEM_LOAD_16dp32b16x{}; auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); tDQtDQ.data() = TmemAllocation::kDQ; Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}) (_, _, _, _0{}, blk_coord_batch); Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp); auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); auto thread_t2r = tiled_t2r.get_slice(thread_idx); Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); Tensor tDQsDQ = 
block_tma.partition_S(sDQ); Tensor tDQcDQ = block_tma.partition_S(cDQ); Tensor tDQgDQ = block_tma.partition_D(gDQ); int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; while (iter_count > 0) { pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); // load dQ from tmem to rmem cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); ++pipeline_mma_reduce_dq_consumer_state; // we don't have enough smem to dump it all to smem, so we do it in stages CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<2>(tTR_cDQ); i++) { if (lane_predicate) { pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); } // wait in all threads for the acquire to complete cutlass::arch::NamedBarrier( kNumReduceWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransposeBarrier ).arrive_and_wait(); cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); // wait for the stores to all be visible to the TMA cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier( kNumReduceWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransposeBarrier ).arrive_and_wait(); if (lane_predicate) { // launch tma store copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); } ++pipeline_reduce_tma_store_producer_state; } iter_count -= 1; iter_index += 1; } } CUTLASS_DEVICE void operator()(Params const& params, char* smem) { #if defined(KERUTILS_ENABLE_SM100A) int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); if (role == WarpRole::Load && lane_predicate) { 
prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); } SharedStorage& shared_storage = *reinterpret_cast(smem); int initializing_warp = 0; typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; if (role == WarpRole::Load) { pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; } if (role == WarpRole::Mma) { pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; } pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); // Also loads K in the first iteration pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; pipeline_load_mma_q_params.initializing_warp = initializing_warp++; PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; if (role == WarpRole::Load) { pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; } if (role == WarpRole::Mma) { pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; } pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); // Also loads V in the first iteration pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; pipeline_load_mma_do_params.initializing_warp = initializing_warp++; PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; if (role == WarpRole::Load) { 
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; } pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; PipelineLoadComputeLSE pipeline_load_compute_lse( shared_storage.pipelines.load_compute_lse, pipeline_load_compute_lse_params, /*barrier init*/ cute::true_type{}); typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; if (role == WarpRole::Load) { pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; } pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( shared_storage.pipelines.load_compute_sum_odo, pipeline_load_compute_sum_odo_params, /*barrier init*/ cute::true_type{}); typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; if (role == WarpRole::Mma) { pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; } pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; PipelineMmaComputeS pipeline_mma_compute_s( shared_storage.pipelines.mma_compute_s, pipeline_mma_compute_s_params, ClusterShape{}, 
/*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; if (role == WarpRole::Mma) { pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; } pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; PipelineMmaComputeDP pipeline_mma_compute_dp( shared_storage.pipelines.mma_compute_dp, pipeline_mma_compute_dp_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; if (role == WarpRole::Mma) { pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; } if (role == WarpRole::Reduce) { pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; } pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; PipelineMmaReduceDQ pipeline_mma_reduce_dq( shared_storage.pipelines.mma_reduce_dq, pipeline_mma_reduce_dq_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; if (role == WarpRole::Mma) { pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; } if (role == WarpRole::Compute) { pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; } pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_compute_mma_p_params.consumer_arv_count = 1; pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; PipelineComputeMmaP pipeline_compute_mma_p( 
shared_storage.pipelines.compute_mma_p, pipeline_compute_mma_p_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; if (role == WarpRole::Mma) { pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; } if (role == WarpRole::Compute) { pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; } pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_compute_mma_ds_params.consumer_arv_count = 1; pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; PipelineComputeMmaDS pipeline_compute_mma_ds( shared_storage.pipelines.compute_mma_ds, pipeline_compute_mma_ds_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; if (role == WarpRole::Mma) { pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; } if (role == WarpRole::Compute) { pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; } pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( shared_storage.pipelines.mma_compute_dkdv, pipeline_mma_compute_dkdv_params, ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); PipelineReduceTmaStore pipeline_reduce_tma_store; TmemAllocator tmem_allocator; pipeline_init_arrive_relaxed(size(ClusterShape{})); pipeline_load_mma_q.init_masks(ClusterShape{}); pipeline_load_mma_do.init_masks(ClusterShape{}); pipeline_mma_compute_s.init_masks(ClusterShape{}); pipeline_mma_compute_dp.init_masks(ClusterShape{}); pipeline_mma_reduce_dq.init_masks(ClusterShape{}); 
pipeline_compute_mma_p.init_masks(ClusterShape{}); pipeline_compute_mma_ds.init_masks(ClusterShape{}); pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; auto pipeline_load_mma_q_producer_state = make_producer_start_state(); auto pipeline_load_mma_do_producer_state = make_producer_start_state(); auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); 
pipeline_init_wait(size(ClusterShape{})); auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); auto [problem_shape, blk_offset] = apply_variable_length_offset( params.problem_shape, blk_coord ); int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); int iter_start = 0; if constexpr (std::is_base_of_v, Mask>) { iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; } else if constexpr (std::is_base_of_v, Mask>) { int offset = get<1>(problem_shape) - get<0>(problem_shape); iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); } if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { return; } iter_count -= iter_start; if (iter_count <= 0) { epilogue_clear( blk_coord, blk_offset, problem_shape, params.mainloop, params.epilogue ); return; } if (role == WarpRole::Load) { warpgroup_reg_set(); load( blk_coord, blk_offset, problem_shape, iter_start, iter_count, params.mainloop, params.mainloop_params, shared_storage.tensors, pipeline_load_mma_q, pipeline_load_mma_q_producer_state, pipeline_load_mma_do, pipeline_load_mma_do_producer_state, pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state ); } else if (role == WarpRole::Mma) { warpgroup_reg_set(); tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); __syncwarp(); mma( blk_coord, problem_shape, iter_start, iter_count, params.mainloop, shared_storage.tensors, pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, 
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state ); } else if (role == WarpRole::Compute) { warpgroup_reg_set(); compute( blk_coord, blk_offset, problem_shape, iter_start, iter_count, params.mainloop, params.epilogue, shared_storage.tensors, pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state ); cutlass::arch::NamedBarrier( kNumComputeWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier ).arrive_and_wait(); if (warp_idx % kNumComputeWarps == 0) { uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); } } else if (role == WarpRole::Reduce) { warpgroup_reg_set(); reduce( blk_coord, problem_shape, iter_start, iter_count, params.mainloop, params.mainloop_params, shared_storage.tensors, pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state ); pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); } else { warpgroup_reg_set(); /* no-op */ } #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); } #endif } static dim3 get_block_shape() { dim3 block(MaxThreadsPerBlock, 1, 1); return block; } static dim3 get_grid_shape(Params const& params) { auto [Q, K, D, D_VO, HB] = params.problem_shape; auto [H, B] = HB; dim3 grid(ceil_div(K, TileShapeK{}), H, B); return grid; } }; } // namespace cutlass::fmha::kernel ================================================ FILE: 
csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 * **************************************************************************************************/
// NOTE(review): in this copy the contents of angle brackets were lost (the bare `#include`s,
// `std::is_same_v;`, `reinterpret_cast(...)`, `warpgroup_reg_set();`, ...), and `¶ms` is
// presumably a mangled `&params`. The comments below only describe what is visible here;
// restore the missing arguments from upstream before compiling.
#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/arch/tmem_allocator_sm100.hpp"
#include // for KERUTILS_ENABLE_SM100A

#include "../kernel/fmha_options.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"
#include "../kernel/fmha_causal_tile_scheduler.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/fmha_common.hpp"

namespace cutlass::fmha::kernel {

using namespace cute;
using namespace cutlass::fmha::collective;

// Warp-role schedule for the context FMHA forward kernel (16 warps per CTA):
//   warps 0-3  -> Softmax0      warps 4-7 -> Softmax1      warps 8-11 -> Correction
//   warp  12   -> MMA           warp  13  -> Load          warp  14   -> Epilogue
//   warp  15   -> Empty (no-op in the kernel body)
struct Sm100FmhaCtxKernelWarpspecializedSchedule {

  enum class WarpRole {
    Softmax0,
    Softmax1,
    Correction,
    MMA,
    Load,
    Epilogue,
    Empty
  };

  // Map a CTA-local warp index to its role, following the table above.
  static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
    int wg_idx = warp_idx / 4;                      // warpgroup index (4 warps each)
    if (wg_idx == 0) return WarpRole::Softmax0;     // warps 0 - 3
    if (wg_idx == 1) return WarpRole::Softmax1;     // warps 4 - 7
    if (wg_idx == 2) return WarpRole::Correction;   // warps 8 - 11
    if (warp_idx == 12) return WarpRole::MMA;       // warp 12
    if (warp_idx == 13) return WarpRole::Load;      // warp 13
    if (warp_idx == 14) return WarpRole::Epilogue;  // warp 14
    return WarpRole::Empty;                         // warp 15
  }

  static const int NumWarpsSoftmax = 4;
  static const int NumWarpsCorrection = 4;
  static const int NumWarpsEpilogue = 1;
  static const int NumWarpsLoad = 1;

  // When true, 16 registers move from the Correction budget to the "other" budget
  // (presumably to make room for device-side printf while debugging).
  static const bool kDebugUsingPrintf = false;
  // Register counts per role; presumably consumed by the warpgroup_reg_set /
  // warpgroup_reg_dealloc calls in the kernel's role dispatch (arguments not legible here).
  static const int NumRegsSoftmax = 192;
  static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
  static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
  static const int NumRegsEmpty = 24;

  static const int NumWarps = 16;

};

// MLA variant of the schedule: the same 16-warp role layout as
// Sm100FmhaCtxKernelWarpspecializedSchedule; only the register budgets differ
// (Softmax 184 instead of 192, Other 48 instead of 32).
struct Sm100MlaFwdCtxKernelWarpspecializedSchedule {

  enum class WarpRole {
    Softmax0,
    Softmax1,
    Correction,
    MMA,
    Load,
    Epilogue,
    Empty
  };

  // Map a CTA-local warp index to its role (same mapping as the non-MLA schedule).
  static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
    int wg_idx = warp_idx / 4;                      // warpgroup index (4 warps each)
    if (wg_idx == 0) return WarpRole::Softmax0;     // warps 0 - 3
    if (wg_idx == 1) return WarpRole::Softmax1;     // warps 4 - 7
    if (wg_idx == 2) return WarpRole::Correction;   // warps 8 - 11
    if (warp_idx == 12) return WarpRole::MMA;       // warp 12
    if (warp_idx == 13) return WarpRole::Load;      // warp 13
    if (warp_idx == 14) return WarpRole::Epilogue;  // warp 14
    return WarpRole::Empty;                         // warp 15
  }

  static const int NumWarpsSoftmax = 4;
  static const int NumWarpsCorrection = 4;
  static const int NumWarpsEpilogue = 1;
  static const int NumWarpsLoad = 1;

  static const bool kDebugUsingPrintf = false;
  static const int NumRegsSoftmax = 184;
  static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
  static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0);
  static const int NumRegsEmpty = 24;

  static const int NumWarps = 16;

};

// Warp-specialized SM100 FMHA forward kernel.
// Each warp takes the role returned by KernelSchedule::warp_idx_to_WarpRole and walks the tiles
// handed out by TileScheduler. The roles hand work to each other through the pipelines stored in
// SharedStorage::PipelineStorage:
//   Load -(Q, KV)-> MMA -(S0/S1)-> Softmax0/1 -(C0/C1)-> Correction <-(O)- MMA
//   Correction -(E)-> Epilogue
template<
  class ProblemShapeIn,
  class CollectiveMainloop,
  class CollectiveEpilogue,
  class TileScheduler,
  class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule
>
struct Sm100FmhaFwdKernelTmaWarpspecialized {

  using TileShape = typename CollectiveMainloop::TileShape;
  using ProblemShape = ProblemShapeIn;

  using WarpRole = typename KernelSchedule::WarpRole;

  // Forwarded to the schedule so the kernel body does not depend on a specific schedule.
  constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
    return KernelSchedule::warp_idx_to_WarpRole(warp_idx);
  }

  static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;
  static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
  static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
  static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;

  // The epilogue collective must be built for the same warp counts as the schedule.
  static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue);
  static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad);

  static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
  static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
  static const int NumRegsOther = KernelSchedule::NumRegsOther;
  // NOTE(review): hard-coded instead of KernelSchedule::NumRegsEmpty. Both schedules in this file
  // use 24, so they agree today, but a schedule with a different value would be ignored.
  static const int NumRegsEmpty = 24;

  static const int NumWarps = KernelSchedule::NumWarps;

  // Presumably true when KernelSchedule is the MLA schedule (type arguments not legible here).
  static constexpr bool IsMla = std::is_same_v;

  using ClusterShape = typename CollectiveMainloop::ClusterShape;

  using TmemAllocator = cute::TMEM::Allocator1Sm;

  struct SharedStorage {
    // Mainloop and epilogue tensor storage either overlap (UnionType) or sit side by side
    // (StructType).
    using UnionType = union {
      typename CollectiveMainloop::TensorStorage mainloop;
      typename CollectiveEpilogue::TensorStorage epilogue;
    };

    using StructType = struct {
      typename CollectiveMainloop::TensorStorage mainloop;
      typename CollectiveEpilogue::TensorStorage epilogue;
    };

    // Presumably true for the persistent tile schedulers (type arguments not legible here).
    static constexpr bool IsPersistent = std::is_same_v || std::is_same_v;
    // Picks StructType or UnionType; the exact condition is not legible in this copy —
    // TODO confirm upstream (it looks like it depends on IsPersistent and IsMla).
    using MainloopEpilogueStorage = std::conditional_t,
                                                   StructType>,
                                                   UnionType>;

    MainloopEpilogueStorage mainloop_epilogue;

    // One barrier set per producer/consumer pipeline used in operator().
    struct PipelineStorage {
      alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
      alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;
      alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;
      alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;
      alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;
      alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;
      alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;
      alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;
      alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;
    } pipelines;

    // Written by the MMA warp's TMEM allocation; read when freeing TMEM.
    uint32_t tmem_base_ptr;
  };

  static constexpr int SharedStorageSize = sizeof(SharedStorage);

  // Host-side arguments.
  struct Arguments {
    ProblemShape problem_shape;
    typename CollectiveMainloop::Arguments mainloop;
    typename CollectiveEpilogue::Arguments epilogue;
    cutlass::KernelHardwareInfo hw_info;
  };

  // Device-side parameters built by to_underlying_arguments.
  struct Params {
    ProblemShape problem_shape;
    typename CollectiveMainloop::Params mainloop;
    typename CollectiveEpilogue::Params epilogue;
    typename TileScheduler::Params tile_scheduler;
  };

  static const int MinBlocksPerMultiprocessor = 1;
  static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;
  using ArchTag = cutlass::arch::Sm100;

  // This kernel needs no global workspace.
  static size_t get_workspace_size(Arguments const& args) { return 0; }
  static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
    return cutlass::Status::kSuccess;
  }

  // Support is decided entirely by the mainloop collective.
  static bool can_implement(Arguments const& args) {
    return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
  }

  static dim3 get_grid_shape(Params const& params) {
    return TileScheduler::get_grid_shape(params.tile_scheduler);
  }

  // One CTA of NumWarps warps, 1-D.
  static dim3 get_block_shape() {
    dim3 block(MaxThreadsPerBlock, 1, 1);
    return block;
  }

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return Params{
      args.problem_shape,
      CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
      CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
      TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{})
    };
  }

  // Problem shape for batch `batch_idx`, via apply_variable_length.
  // NOTE(review): the `problem_shape` argument is unused; params.problem_shape is read instead.
  // Every call site below passes params.problem_shape, so the result is the same today.
  CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) {
    return apply_variable_length(params.problem_shape, batch_idx);
  }

  CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
#if defined(KERUTILS_ENABLE_SM100A)
    TileScheduler tile_scheduler{params.tile_scheduler};

    int warp_idx = cutlass::canonical_warp_idx_sync();
    auto role = warp_idx_to_WarpRole(warp_idx);
    // One elected lane per warp: used for descriptor prefetch and as the pipeline leader.
    uint32_t lane_predicate = cute::elect_one_sync();

    // The Load and Epilogue warps each prefetch their collective's TMA descriptors once.
    if (role == WarpRole::Load && lane_predicate) {
      CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
    }

    if (role == WarpRole::Epilogue && lane_predicate) {
      CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
    }

    SharedStorage& shared_storage = *reinterpret_cast(smem);

    // For MLA with IsOrderLoadEpilogue, the epilogue's tensor storage is placed on top of the
    // mainloop's smem_o buffer; otherwise it uses mainloop_epilogue.epilogue.
    auto get_epilogue_storage = [&]() {
      if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
        return reinterpret_cast(shared_storage.mainloop_epilogue.mainloop.smem_o.data());
      } else {
        return &shared_storage.mainloop_epilogue.epilogue;
      }
    };
    typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage();

    // Q: Load -> MMA. The elected Load lane is the leader and expects TransactionBytesLoadQ bytes.
    typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
    if (role == WarpRole::Load) {
      pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
    }
    if (role == WarpRole::MMA) {
      pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;
    }
    pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
    pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ;
    typename CollectiveMainloop::PipelineQ pipeline_load_q(
      shared_storage.pipelines.load_q,
      pipeline_load_q_params,
      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});

    // K/V: Load -> MMA, with TransactionBytesLoadK bytes per stage.
    typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
    if (role == WarpRole::Load) {
      pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
    }
    if (role == WarpRole::MMA) {
      pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
    }
    pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);
    pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK;
    typename CollectiveMainloop::PipelineKV pipeline_load_kv(
      shared_storage.pipelines.load_kv,
      pipeline_load_kv_params,
      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});

    // S for softmax group 0: MMA -> Softmax0; every Softmax0 thread arrives on release.
    typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;
    if (role == WarpRole::MMA) {
      pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
    }
    if (role == WarpRole::Softmax0) {
      pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
    }
    pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
    typename CollectiveMainloop::PipelineS pipeline_mma_s0(
      shared_storage.pipelines.mma_s0,
      pipeline_mma_s0_params,
      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});

    // S for softmax group 1: MMA -> Softmax1.
    typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;
    if (role == WarpRole::MMA) {
      pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
    }
    if (role == WarpRole::Softmax1) {
      pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
    }
    pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
    typename CollectiveMainloop::PipelineS pipeline_mma_s1(
      shared_storage.pipelines.mma_s1,
      pipeline_mma_s1_params,
      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});

    // Softmax0 -> Correction; all threads on both sides arrive.
    typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;
    if (role == WarpRole::Softmax0) {
      pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
    }
    if (role == WarpRole::Correction) {
      pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
    }
    pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
    pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
    typename CollectiveMainloop::PipelineC pipeline_s0_corr(
      shared_storage.pipelines.s0_corr,
      pipeline_s0_corr_params,
      /*barrier init*/ cute::true_type{});

    // Softmax1 -> Correction.
    typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;
    if (role == WarpRole::Softmax1) {
      pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
    }
    if (role == WarpRole::Correction) {
      pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
    }
    pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
    pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
    typename CollectiveMainloop::PipelineC pipeline_s1_corr(
      shared_storage.pipelines.s1_corr,
      pipeline_s1_corr_params,
      /*barrier init*/ cute::true_type{});

    // O accumulator: MMA -> Correction.
    typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;
    if (role == WarpRole::MMA) {
      pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;
    }
    if (role == WarpRole::Correction) {
      pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;
    }
    pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
    typename CollectiveMainloop::PipelineO pipeline_mma_corr(
      shared_storage.pipelines.mma_corr,
      pipeline_mma_corr_params,
      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});

    // Correction -> Epilogue.
    typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;
    if (role == WarpRole::Correction) {
      pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;
    }
    if (role == WarpRole::Epilogue) {
      pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
    }
    pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
    pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
    typename CollectiveMainloop::PipelineE pipeline_corr_epi(
      shared_storage.pipelines.corr_epi,
      pipeline_corr_epi_params,
      /*barrier init*/ cute::true_type{});

    // Ordering barrier between the two softmax groups: Softmax1 is group 1, every other warp
    // (including Softmax0) is given group 0.
    typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;
    params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;
    params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
    typename CollectiveMainloop::OrderBarrierSoftmax order_s01(
      shared_storage.pipelines.order_s01, params_order_s01);

    TmemAllocator tmem_allocator;

    // The pipelines above initialize their barriers in place (barrier init = true_type).
    // Synchronize the CTA before anything uses them.
    __syncthreads();

    // Mask calculation was skipped at construction (mask calc = false_type); do it now.
    pipeline_load_q.init_masks(ClusterShape{});
    pipeline_load_kv.init_masks(ClusterShape{});
    pipeline_mma_s0.init_masks(ClusterShape{});
    pipeline_mma_s1.init_masks(ClusterShape{});
    pipeline_mma_corr.init_masks(ClusterShape{});

    // Consumer states are default-constructed; producer states start from
    // make_producer_start_state().
    typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;
    typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state();

    typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;
    typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state();

    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;
    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state();

    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;
    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state();

    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;
    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state();

    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;
    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state();

    typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;
    typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state();

    typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;
    typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state();

    CollectiveMainloop mainloop;
    CollectiveEpilogue epilogue{params.epilogue};

    // Role dispatch. Every role walks the same tile sequence and applies the same first skip test
    // (tile start beyond mode 0 of the batch's shape), so the per-role state stays in step.
    if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
      warpgroup_reg_set();

      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();

        auto logical_problem_shape = apply_batch(params,
            params.problem_shape, get<2,1>(blk_coord));

        // Tile starts past this batch's mode-0 extent (presumably the query length).
        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
          continue;
        }

        // Empty mode-1 extent (presumably zero keys): no scores for softmax to process.
        if (get<1>(logical_problem_shape) == 0) {
          continue;
        }

        bool is_softmax_0 = role == WarpRole::Softmax0;

        mainloop.softmax(
          is_softmax_0 ? 0 : 1, blk_coord,
          params.mainloop, logical_problem_shape,
          is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,
          is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,
          is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,
          is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,
          order_s01
        );

      }
    }
    else if (role == WarpRole::Correction) {
      cutlass::arch::warpgroup_reg_dealloc();

      // Set once this warp has handled at least one tile; gates the TMEM free below.
      bool has_valid = false;

      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();

        auto logical_problem_shape = apply_batch(params,
            params.problem_shape, get<2,1>(blk_coord));

        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
          continue;
        }

        has_valid = true;

        // Empty mode-1 extent: still hand a result to the epilogue via pipeline_corr_epi.
        if (get<1>(logical_problem_shape) == 0) {
          mainloop.correction_empty(
            blk_coord,
            params.mainloop, logical_problem_shape,
            params.problem_shape,
            epilogue_storage,
            pipeline_corr_epi, pipeline_corr_epi_producer_state,
            epilogue
          );
          continue;
        }

        mainloop.correction(
          blk_coord,
          params.mainloop, logical_problem_shape,
          params.problem_shape,
          epilogue_storage,
          pipeline_s0_corr, pipeline_s0_corr_consumer_state,
          pipeline_s1_corr, pipeline_s1_corr_consumer_state,
          pipeline_mma_corr, pipeline_mma_corr_consumer_state,
          pipeline_corr_epi, pipeline_corr_epi_producer_state,
          epilogue
        );


      }

      // Without an epilogue warp, the (single) correction warp releases TMEM.
      if constexpr (NumWarpsEpilogue == 0) {
        static_assert(NumWarpsCorrection == 1);

        if (has_valid) {
          uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
          tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
        }
      }

    }
    else if (role == WarpRole::MMA) {
      warpgroup_reg_set();

      // TMEM is allocated once, on the first tile this CTA actually processes. It is freed by
      // whichever role saw a valid tile (Epilogue, or Correction when there is no epilogue warp).
      bool allocated = false;

      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();

        auto logical_problem_shape = apply_batch(params,
            params.problem_shape, get<2,1>(blk_coord));

        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
          continue;
        }

        if (!allocated) {
          tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
          __syncwarp();
          allocated = true;
        }

        // Allocation above happens even for an empty mode-1 extent, so the free stays matched.
        if (get<1>(logical_problem_shape) == 0) {
          continue;
        }

        mainloop.mma(
          blk_coord,
          params.mainloop, logical_problem_shape,
          shared_storage.mainloop_epilogue.mainloop,
          pipeline_load_q, pipeline_load_q_consumer_state,
          pipeline_load_kv, pipeline_load_kv_consumer_state,
          pipeline_mma_s0, pipeline_mma_s0_producer_state,
          pipeline_mma_s1, pipeline_mma_s1_producer_state,
          pipeline_mma_corr, pipeline_mma_corr_producer_state
        );


      }
    }
    else if (role == WarpRole::Load) {
      warpgroup_reg_set();

      // When the epilogue reuses the mainloop's smem_o, Load arrives once on EpilogueBarrier
      // before starting its loads (arrive only; the matching wait is not in this file).
      if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
        cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
      }

      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();

        auto logical_problem_shape = apply_batch(params,
            params.problem_shape, get<2,1>(blk_coord));

        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
          continue;
        }

        // Nothing to load when the mode-1 extent is empty.
        if (get<1>(logical_problem_shape) == 0) {
          continue;
        }

        mainloop.load(
          blk_coord, logical_problem_shape,
          params.mainloop, params.problem_shape,
          shared_storage.mainloop_epilogue.mainloop,
          pipeline_load_q, pipeline_load_q_producer_state,
          pipeline_load_kv, pipeline_load_kv_producer_state
        );

      }
    }
    else if (role == WarpRole::Epilogue) {
      warpgroup_reg_set();

      // Set once this warp has stored at least one tile; gates the TMEM free below.
      bool has_valid = false;

      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();

        auto logical_problem_shape = apply_batch(params,
            params.problem_shape, get<2,1>(blk_coord));

        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
          continue;
        }

        has_valid = true;

        // Stores every valid tile, including empty-K tiles (Correction feeds those via
        // correction_empty).
        epilogue.store(
          blk_coord, logical_problem_shape,
          params.epilogue, params.problem_shape,
          epilogue_storage,
          pipeline_corr_epi, pipeline_corr_epi_consumer_state
        );

      }

      static_assert(NumWarpsEpilogue <= 1);
      // With one epilogue warp, it is the one that releases TMEM.
      if constexpr (NumWarpsEpilogue == 1) {
        if(has_valid) {
          uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
          tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
        }
      }

    }
    else if (role == WarpRole::Empty) {
      warpgroup_reg_set();

      /* no-op, donate regs and exit */
    }
#else
    // Built without SM100A support: any launch hits an invalid-control-path error from thread 0.
    if (cute::thread0()) {
      CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
    }
#endif
  }

};

} // namespace cutlass::fmha::kernel
================================================
FILE: csrc/sm100/prefill/sparse/common_subroutine.h
================================================
#pragma once

#include
#include

namespace sm100 {

/*
Load K/V indices from global memory, and generate validity mask.

Each thread (lane) loads 8 consecutive indices, gIndices[lane_idx*8 .. lane_idx*8 + 7], through
KU_LDG_256. Bit i of the returned mask is set iff indices[i] lies in [0, s_kv) and its absolute
position abs_pos_start + lane_idx*8 + i is < topk_length.

Since each lane covers 8 indices, a block of BLOCK_TOPK indices is covered by
lanes 0 ~ BLOCK_TOPK/8 - 1; only those lanes should call this.

The mask is returned as a `char`: bit 7 set gives the value 128, so the result may read as
negative where `char` is signed. Treat it as raw bits.
*/
CUTE_DEVICE
char load_indices_and_generate_mask(
    int lane_idx,
    int* gIndices,
    int s_kv,
    int abs_pos_start,
    int topk_length
) {
    int indices[8];
    KU_LDG_256(
        gIndices + lane_idx*8,
        indices,
        ".nc",
        "no_allocate",
        "evict_normal",
        "256B"
    );

    // 1 when the index points inside the KV sequence and its slot is within topk_length.
    auto is_valid = [&](int rel_pos_in_lane, int index) -> char {
        int abs_pos = abs_pos_start + lane_idx*8 + rel_pos_in_lane;
        return index >= 0 && index < s_kv && abs_pos < topk_length;
    };
    // Pack the 8 validity flags, index i -> bit i.
    char is_ks_valid_mask = \
        is_valid(7, indices[7]) << 7 |
        is_valid(6, indices[6]) << 6 |
        is_valid(5, indices[5]) << 5 |
        is_valid(4, indices[4]) << 4 |
        is_valid(3, indices[3]) << 3 |
        is_valid(2, indices[2]) << 2 |
        is_valid(1, indices[1]) << 1 |
        is_valid(0, indices[0]) << 0;

    return is_ks_valid_mask;
}

/*
Get P from Tensor Memory, mask invalid tokens, reduce P within shared memory, and store back if necessary

Initially, since dual gemm is used, we have two P pieces in Tensor Memory,
one occupying rows 0 ~ 63 while the other occupying rows 64 ~ 127.
We'd like to have them reduced into one single P piece, stored in registers with layout: N N --- (topk) +-------+-------+ | | | 32 | Warp0 | Warp2 | | | | +-------+-------+ | | | 32 | Warp1 | Warp3 | | | | +-------+-------+ | (head) where N = NUM_ELEMS_PER_THREAD */ template< int NUM_ELEMS_PER_THREAD, int TMEM_COL_START, int BARRIER_WARP02_SYNC_ID, int BARRIER_WARP13_SYNC_ID, bool STORE_BACK_P > CUTE_DEVICE void retrieve_mask_and_reduce_p( char* k_validness_base, int local_warp_idx, int lane_idx, auto slot_bar_P_empty_arrival, float p_exchange_buf[4][32*NUM_ELEMS_PER_THREAD], float p[NUM_ELEMS_PER_THREAD] ) { using namespace cute; using cutlass::arch::NamedBarrier; static_assert(BARRIER_WARP13_SYNC_ID == BARRIER_WARP02_SYNC_ID+1); float p_peer[NUM_ELEMS_PER_THREAD]; if (local_warp_idx < 2) { ku::tmem_ld_32dp32bNx(TMEM_COL_START, p); ku::tmem_ld_32dp32bNx(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p_peer); } else { ku::tmem_ld_32dp32bNx(TMEM_COL_START, p_peer); ku::tmem_ld_32dp32bNx(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p); } cutlass::arch::fence_view_async_tmem_load(); ku::tcgen05_before_thread_sync(); slot_bar_P_empty_arrival(); // Mask invalid tokens // We put masking before reduction, since (-inf) + anything (except nan and +inf) is (-inf), which guarantees correctness, and this can overlap with smem load static_assert(NUM_ELEMS_PER_THREAD == 32); uint32_t is_k_valid = *(uint32_t*)(k_validness_base + (local_warp_idx>=2?NUM_ELEMS_PER_THREAD/8:0)); CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_PER_THREAD; i += 1) { if (!(is_k_valid >> i & 1)) p[i] = -CUDART_INF_F; } // Reduce P within the cluster { // Store // Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (row 0 ~ 31) part CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) { ku::st_shared(&p_exchange_buf[local_warp_idx^2][i*32*4 + lane_idx*4], *(float4*)(p_peer + i*4)); } NamedBarrier::arrive_and_wait(64, BARRIER_WARP02_SYNC_ID + (local_warp_idx&1)); CUTE_UNROLL for (int i = 
0; i < NUM_ELEMS_PER_THREAD/4; ++i) { float2 t[2]; *(float4*)t = *(float4*)(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4]); float2* cur_p = (float2*)(p + i*4); cur_p[0] = ku::float2_add(cur_p[0], t[0]); cur_p[1] = ku::float2_add(cur_p[1], t[1]); } } if constexpr (STORE_BACK_P) { CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) { ku::st_shared(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4], *(float4*)(p+i*4)); } } } /* Rescale O in Tensor Memory. O should occupy 128 rows x (D_V/2) columns in Tensor Memory. */ template< int D_V, int CHUNK_SIZE, int TMEM_COL_START > CUTE_DEVICE void rescale_O( float scale_factor ) { float2 scale_factor_float2 = {scale_factor, scale_factor}; float2 o[CHUNK_SIZE/2]; CUTE_UNROLL for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) { // Load O ku::tmem_ld_32dp32bNx(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o); cutlass::arch::fence_view_async_tmem_load(); // Mult for (int i = 0; i < CHUNK_SIZE/2; ++i) { o[i] = ku::float2_mul(o[i], scale_factor_float2); } // Store O ku::tmem_st_32dp32bNx(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o); cutlass::arch::fence_view_async_tmem_store(); } } template CUTE_DEVICE float get_max( float p[NUM_ELEMS_PER_THREAD] ) { float local_max = -CUDART_INF_F; CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_PER_THREAD; ++i) { local_max = max(local_max, p[i]); } return local_max; } /* Calculate s := exp2f(p*scale - new_max) and its sum */ template CUTE_DEVICE float get_s_from_p( nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2], float p[NUM_ELEMS_PER_THREAD], float scale, float new_max ) { float2 cur_sum = float2 {0.0f, 0.0f}; float2 neg_new_max_float2 = float2 {-new_max, -new_max}; float2 scale_float2 = float2 {scale, scale}; CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_PER_THREAD/2; i += 1) { float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale_float2, neg_new_max_float2); d.x = exp2f(d.x); d.y = exp2f(d.y); cur_sum = ku::float2_add(cur_sum, d); s[i] = __float22bfloat162_rn(d); } return cur_sum.x 
+ cur_sum.y; } } ================================================ FILE: csrc/sm100/prefill/sparse/fwd/head128/config.h ================================================ #pragma once #include #include #include #include "params.h" #include "defines.h" namespace sm100::fwd::head128 { using namespace cute; template< typename Shape_Q, typename TMA_Q, typename Shape_O, typename TMA_O > struct TmaParams { Shape_Q shape_Q; TMA_Q tma_Q; Shape_O shape_O; TMA_O tma_O; CUtensorMap tensor_map_kv; }; struct float2x2 { float2 lo, hi; }; template struct KernelTemplate { static constexpr int D_Q = D_QK; static constexpr int D_K = D_QK; static constexpr int D_V = 512; static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan static constexpr int B_H = 128; // For 2 CTAs static constexpr int B_TOPK = 128; // For 2 CTAs static constexpr int NUM_BUFS = 2; static constexpr int NUM_THREADS = 256 + 128 + 128; // 128 scale & exp threads, 128x2 TMA threads, 32 UTCMMA threads static constexpr int D_tQ = 384, NUM_tQ_TILES = D_tQ / 64; static constexpr int D_sQ = D_QK-D_tQ, NUM_sQ_TILES = D_sQ / 64; static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q); // Tensor memory columns struct tmem_cols { // 0 ~ 256: output // 256 ~ 320: P // 320 ~ 512: Q[D_QK-D_tQ:] static constexpr int o = 0; static constexpr int p = 256; static constexpr int q = 512 - D_tQ/2; static_assert(p+64 <= q); }; template using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( UMMA::Layout_K_SW128_Atom{}, Shape, Int<64*NUM_TILES>>{}, Step<_1, _2>{} ), Shape<_1, _1>{})); template using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( UMMA::Layout_K_SW128_Atom{}, Shape, Int<64*NUM_TILES>>{}, Step<_1, _2>{} ), Shape<_1, _1>{})); using SmemLayoutO = SmemLayoutOTiles<8>; template using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( UMMA::Layout_K_SW128_Atom{}, Shape, Int<64*NUM_TILES>>{}, Step<_1, _2>{} ), Shape<_1, _1>{})); using 
SmemLayoutV = decltype(coalesce(tile_to_shape( UMMA::Layout_MN_SW128_Atom{}, Shape, Int>{}, Step<_2, _1>{} ), Shape<_1, _1>{})); template using SmemLayoutSTiles = decltype(coalesce(tile_to_shape( UMMA::Layout_K_INTER_Atom{}, Shape, Int<64*NUM_TILES>>{}, Step<_1, _2>{} ), Shape<_1, _1>{})); struct SharedMemoryPlan { union { array_aligned>> q_full; struct { array_aligned>> sq; array_aligned> v; // NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q static_assert(cosize_v> <= cosize_v> + cosize_v); array_aligned>> k; } s; array_aligned> o; } u; array_aligned>> s; float p[(B_H/2)*B_TOPK]; char is_k_valid[NUM_BUFS][B_TOPK/8]; transac_bar_t bar_prologue_q, bar_prologue_utccp; transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free) transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. Vi free) transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS]; transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready transac_bar_t bar_p_free[NUM_BUFS]; transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS]; array_aligned tmem_start_addr; float rowwise_max_buf[128], rowwise_li_buf[128]; }; using TiledMMA_P_tQ = decltype(make_tiled_mma( SM100_MMA_F16BF16_2x1SM_TS_NOELECT{} )); using TiledMMA_P_sQ = decltype(make_tiled_mma( SM100_MMA_F16BF16_2x1SM_SS_NOELECT{} )); using TiledMMA_O = decltype(make_tiled_mma( SM100_MMA_F16BF16_2x1SM_SS_NOELECT{}, Layout>{}, Tile, Layout, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512] )); template static __device__ void sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams ¶ms, const TmaParams &tma_params); }; } ================================================ FILE: csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu 
================================================ #include "../phase1.h" #include "../phase1.cuh" namespace sm100::fwd::head128 { template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params); } ================================================ FILE: csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu ================================================ #include "../phase1.h" #include "../phase1.cuh" namespace sm100::fwd::head128 { template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params); } ================================================ FILE: csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh ================================================ #pragma once #include "phase1.h" #include #include #include #include #include #include #include "params.h" #include "utils.h" #include "sm100/helpers.h" #include "config.h" namespace sm100::fwd::head128 { using namespace cute; CUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) { int32x8_t val; asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" : "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3), "=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7) : "l"(src_ptr) ); return val; } /* Pipeline Overview: | Copy | MMA | Scale & Exp | K0 V0 P0 = QK0^T K1 S0 = exp(P0) scale(O) w.r.t P0 P1 = QK1^T K2 S1 = exp(P1) O += S0V0 V1 scale(O) w.r.t P1 P2 = QK2^T K3 S2 = exp(P2) O += S1V1 V2 scale(O) w.r.t P2 P3 = QK3^T K4 S3 = exp(P3) O += S2V2 V3 scale(O) w.r.t P3 ... 
O += S(n-3)V(n-3) V(n-2) scale(O) w.r.t P(n-2) P(n-1) = QK(n-1)^T S(n-1) = exp(P(n-1)) O += S(n-2)V(n-2) V(n-1) scale(O) w.r.t P(n-1) O += S(n-1)V(n-1) */

// NOTE(review): in this extracted copy the contents of most `<...>` template
// argument lists are missing (e.g. `template template`, `reinterpret_cast(wksp_buf)`,
// `Shape, Int>{}`). They are reproduced as-is below and must be restored from the
// upstream source before this compiles.
//
// Device-side body of the SM100 sparse attention forward kernel (head128 variant).
// A 2-CTA cluster handles one query token: cta_idx = blockIdx.x % 2 selects the CTA,
// s_q_idx = blockIdx.x / 2 selects the token. Each CTA owns B_H/2 heads
// (see global_index / attn_sink indexing in the epilogue). Warp roles:
//   - warpgroup 0 (threads 0~127): mask + online softmax ("scale & exp"), O rescaling, epilogue
//   - warpgroup 1 (warps 4~7):     TMA gather of K rows for every top-k block
//   - warpgroup 2 (warps 8~11):    TMA gather of V rows for every top-k block
//   - warp 12 (CTA 0 only):        Q smem->tmem copy and issue of all UMMAs
//   - warp 13 (lanes 0~15):        builds the per-block is_k_valid bitmask
// Pipelining between roles uses per-buffer barriers indexed by k%NUM_BUFS and waited
// with parity ((k/NUM_BUFS)&1), i.e. the parity of how often that buffer was reused.
template template __device__ void KernelTemplate::sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
    const int cta_idx = blockIdx.x % 2;
    const int s_q_idx = blockIdx.x / 2;
    const int warp_idx = cutlass::canonical_warp_idx_sync();
    const int lane_idx = threadIdx.x % 32;
    // Per-token top-k length if provided, otherwise the global params.topk
    const int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + s_q_idx) : params.topk;
    const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1);  // num_k_blocks always >= 1
    // Broadcast from lane 0 so the compiler treats warpgroup_idx as warp-uniform
    const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
    const int idx_in_warpgroup = threadIdx.x % 128;

    // Prefetch TMA descriptors
    if (threadIdx.x == 0) {
        cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
        cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv));
    }

    // Define shared tensors
    extern __shared__ char wksp_buf[];
    SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf);
    Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles{});
    int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q;  // [topk]

    // Allocate tmem tensors
    // (the tensors below are only views; their tmem column offsets are assigned explicitly)
    TiledMMA tiled_mma_P_tQ = TiledMMA_P_tQ{};
    TiledMMA tiled_mma_P_sQ = TiledMMA_P_sQ{};
    TiledMMA tiled_mma_O = TiledMMA_O{};
    Tensor tP = partition_fragment_C(tiled_mma_P_tQ, Shape, Int>{});
    Tensor tQr = tiled_mma_P_tQ.get_slice(_0{}).make_fragment_A(
        partition_shape_A(tiled_mma_P_tQ, Shape, Int>{})
    );
    Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{});
    tP.data().get() = tmem_cols::p;
    tQr.data().get() = tmem_cols::q;
    tO.data().get() = tmem_cols::o;

    if (warp_idx == 0) {
        if (elect_one_sync()) {
            // Initialize barriers
            // Arrival counts: 1 for single-issuer (TMA/UMMA) barriers, 128*2 for barriers
            // arrived on by warpgroup 0 of both CTAs, 16 for the 16 k-valid lanes of warp 13.
            plan.bar_prologue_q.init(1);
            plan.bar_prologue_utccp.init(1);
            CUTE_UNROLL
            for (int i = 0; i < NUM_BUFS; ++i) {
                plan.bar_qk_part_done[i].init(1);
                plan.bar_qk_done[i].init(1);
                plan.bar_sv_part_done[i].init(1);
                plan.bar_sv_done[i].init(1);
                plan.bar_k_part0_ready[i].init(1);
                plan.bar_k_part1_ready[i].init(1);
                plan.bar_v_part0_ready[i].init(1);
                plan.bar_v_part1_ready[i].init(1);
                plan.bar_p_free[i].init(128*2);
                plan.bar_so_ready[i].init(128*2);
                plan.bar_k_valid_ready[i].init(16);
                plan.bar_k_valid_free[i].init(128);
            }
            fence_barrier_init();
        }
    }

    cute::cluster_sync();  // We must add a cluster_sync() here, or TMA from CTA1 may launch before barrier initialization in CTA0

    if (warp_idx == 0) {
        if (elect_one_sync()) {
            // Copy Q: each CTA loads its own slice of heads (selected by cta_idx) for token s_q_idx
            Tensor gQ = flat_divide(
                tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
                Tile>{}
            )(_, cta_idx, _);
            ku::launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST);
        }

        // Initialize TMEM: all 512 columns; the column layout in tmem_cols assumes base 0
        cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data());
        TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
        cute::TMEM::Allocator2Sm().release_allocation_lock();
    }

    __syncthreads();  // Wait for TMEM allocation

    if (warpgroup_idx == 0) {
        cutlass::arch::warpgroup_reg_alloc<144>();
        // Scale & Exp warps
        // The following three numbers are
        // - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V)
        // - li: sumexp, i.e. li := sum(exp(Pi*scale - mi))
        // - real_mi: real max logits, i.e. real_mi := max(Pi*scale)
        // where Pi is the i-th row of P, P := QK^T
        // mi and real_mi are always consistent within the two threads that
        // controls one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update
        float mi = MAX_INIT_VAL;
        float li = 0.0f;
        float real_mi = -CUDART_INF_F;
        const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};
        // Each thread writes its S row-fragment to smem starting here (stride 64 uint128 per step)
        uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8);

        CUTE_NO_UNROLL
        for (int k = 0; k < num_k_blocks; ++k) {
            // Wait for P
            plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
            ku::tcgen05_after_thread_sync();

            // Load P
            float2 p[(B_TOPK/2)/2];
            ku::tmem_ld_32dp32bNx(tmem_cols::p, p);
            cutlass::arch::fence_view_async_tmem_load();
            ku::tcgen05_before_thread_sync();
            // P has been copied into registers; tell the MMA warp (in CTA 0) the tmem P may be overwritten
            plan.bar_p_free[k%NUM_BUFS].arrive(0u);

            // Mask
            plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
            // The following code enables NVCC to use R2P instruction
            // Although we perform 2x LDS.32 instructions here, don't worry, NVCC will
            // convert them to one LDS.64 instruction. However, if we write LDS.64
            // here, NVCC won't use R2P.
            // Threads 64~127 own the second half of the top-k columns, hence the B_TOPK/8/2 byte offset.
            uint32_t is_k_valid_lo = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0));
            uint32_t is_k_valid_hi = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0) + 4);
            float* p_float = (float*)p;
            CUTE_UNROLL
            for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
                if (!(is_k_valid_lo >> i & 1))
                    p_float[i] = -CUDART_INF_F;
            }
            CUTE_UNROLL
            for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
                if (!(is_k_valid_hi >> i & 1))
                    p_float[i+(B_TOPK/2)/2] = -CUDART_INF_F;
            }

            // Get rowwise max of Pi
            float cur_pi_max = -CUDART_INF_F;
            CUTE_UNROLL
            for (int i = 0; i < (B_TOPK/2); i += 1) {
                cur_pi_max = max(cur_pi_max, p_float[i]);
            }
            cur_pi_max *= params.sm_scale_div_log2;

            plan.bar_k_valid_free[k%NUM_BUFS].arrive();
            NamedBarrier::arrive_and_wait(128, 0);  // Wait for rowwise_max_buf and sP to be ready
            plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
            NamedBarrier::arrive_and_wait(128, 0);  // TODO Name these barriers
            // Combine with the partner thread that holds the other half of the same row
            cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
            real_mi = max(real_mi, cur_pi_max);
            // Only rescale when some row in this warp grew by more than 6 (log2 domain);
            // otherwise mi is kept, so exp2(p*scale - mi) stays below 2^6.
            bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
            // By this point:
            // - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
            // - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127

            // Calc scale factor, and scale li
            float new_max, scale_for_old;
            if (!should_scale_o) {
                // Don't scale O
                scale_for_old = 1.0f;
                new_max = mi;
            } else {
                new_max = max(cur_pi_max, mi);
                scale_for_old = exp2f(mi - new_max);
            }
            mi = new_max;  // mi is still identical within each row
            li *= scale_for_old;

            // Calculate S = exp2(P*scale - new_max), accumulate li, convert to bf16
            __nv_bfloat162 s[(B_TOPK/2)/2];
            float2 neg_new_max = float2 {-new_max, -new_max};
            CUTE_UNROLL
            for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
                float2 d = ku::float2_fma(p[i], scale, neg_new_max);
                d.x = exp2f(d.x);
                d.y = exp2f(d.y);
                li += d.x + d.y;  // NOTE: Theoretically we could use FFMA2 here but actually this is faster...
                s[i] = __float22bfloat162_rn(d);
            }

            // Wait for last SV gemm, write S
            // (the previous S buffer and O are still being read by the last SV gemm)
            if (k > 0) {
                plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
            }
            CUTE_UNROLL
            for (int i = 0; i < B_TOPK/2/8; i += 1) {
                sS_base[64*i] = *(uint128_t*)(s + i*4);
            }

            // Scale O
            if (k > 0 && should_scale_o) {
                float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old};
                // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);  // NOTE: We have waited for last SV gemm before
                ku::tcgen05_after_thread_sync();

                static constexpr int CHUNK_SIZE = 32;
                float2 o[CHUNK_SIZE/2];
                CUTE_UNROLL
                for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
                    // Load O
                    ku::tmem_ld_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
                    cutlass::arch::fence_view_async_tmem_load();

                    // Mult
                    for (int i = 0; i < CHUNK_SIZE/2; ++i) {
                        o[i] = ku::float2_mul(o[i], scale_for_old_float2);
                    }

                    // Store O
                    ku::tmem_st_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
                    cutlass::arch::fence_view_async_tmem_store();
                }

                ku::tcgen05_before_thread_sync();
            }

            // Make the smem S writes visible to the async (UMMA) proxy, then signal CTA 0's MMA warp
            fence_view_async_shared();
            plan.bar_so_ready[k%NUM_BUFS].arrive(0u);
        }

        // Epilogue

        if (real_mi == -CUDART_INF_F) {
            // real_mi == -CUDART_INF_F <=> No valid TopK indices
            // We set li to 0 to fit the definition that li := exp(x[i] - mi)
            li = 0.0f;
            mi = -CUDART_INF_F;
        }

        // Exchange li
        plan.rowwise_li_buf[idx_in_warpgroup] = li;
        NamedBarrier::arrive_and_wait(128, 0);
        li += plan.rowwise_li_buf[idx_in_warpgroup^64];

        // Store mi and li (converted from the log2 domain to natural log); one thread per row.
        // Rows without any valid index (li == 0) report lse = +inf.
        if (idx_in_warpgroup < 64) {
            int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup;
            float cur_lse = logf(li) + mi*CUDART_LN2_F;
            cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
            params.max_logits[global_index] = real_mi*CUDART_LN2_F;
            params.lse[global_index] = cur_lse;
        }

        // Wait for the last GEMM
        plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);
        ku::tcgen05_after_thread_sync();

        // Store O
        // attn_sink is converted to the log2 domain and contributes exp2(attn_sink - mi) to the normalizer
        float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + cta_idx*B_H/2 + (idx_in_warpgroup%64))*CUDART_L2E_F;
        float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi));
        Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});
        constexpr int B_EPI = 64;
        Tensor tma_gO = flat_divide(
            tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx),
            Shape, Int>{}
        )(_, _, cta_idx, _);
        Tensor sO_divided = flat_divide(
            sO,
            Shape, Int>{}
        )(_, _, _0{}, _);
        auto thr_tma = tma_params.tma_O.get_slice(_0{});

        float2 o[B_EPI/2];
        bool have_valid_indices = __any_sync(0xffffffff, li != 0);  // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld
        if (!have_valid_indices) {
            // If there are no valid indices, we set o[i] to 0 and don't load from TMEM
            CUTE_UNROLL
            for (int i = 0; i < B_EPI/2; ++i)
                o[i].x = o[i].y = 0.0f;
            output_scale = 1.0f;
        }

        float2 output_scale_float2 = make_float2(output_scale, output_scale);
        CUTE_UNROLL
        for (int k = 0; k < (D_V/2)/B_EPI; ++k) {
            // Load O from tO
            if (have_valid_indices) {
                ku::tmem_ld_32dp32bNx(tmem_cols::o + k*B_EPI, o);
                cutlass::arch::fence_view_async_tmem_load();
            }

            // Convert and store
            CUTE_UNROLL
            for (int i = 0; i < B_EPI/8; ++i) {
                __nv_bfloat162 o_bf16[4];
                CUTE_UNROLL
                for (int j = 0; j < 4; ++j) {
                    float2 d = ku::float2_mul(o[i*4+j], output_scale_float2);
                    o_bf16[j] = __float22bfloat162_rn(d);
                }
                int smem_row = idx_in_warpgroup % 64;
                int smem_col = (idx_in_warpgroup/64)*(D_V/2) + k*B_EPI + i*8;
                *(uint128_t*)(&sO(smem_row, smem_col)) = *(uint128_t*)(o_bf16);
            }

            // Sync
            fence_view_async_shared();
            NamedBarrier::arrive_and_wait(128, 0);

            // Warp 0 stores the chunk written by threads 0~63, warp 1 the chunk written by threads 64~127
            if (warp_idx == 0 && elect_one_sync()) {
                cute::copy(
                    tma_params.tma_O,
                    thr_tma.partition_S(sO_divided(_, _, k)),
                    thr_tma.partition_D(tma_gO(_, _, k))
                );
            }
            if (warp_idx == 1 && elect_one_sync()) {
                int k2 = k + (D_V/B_EPI/2);
                cute::copy(
                    tma_params.tma_O,
                    thr_tma.partition_S(sO_divided(_, _, k2)),
                    thr_tma.partition_D(tma_gO(_, _, k2))
                );
            }
        }

        if (warp_idx == 0) {
            cute::TMEM::Allocator2Sm().free(0, 512);
        }
    } else if (warpgroup_idx == 1) {
        // Producer warp for K
        // Each CTA gathers its half of the B_TOPK rows (offset cta_idx*(B_TOPK/2)), all D_K columns.
        cutlass::arch::warpgroup_reg_dealloc<96>();
        int warp_idx = cutlass::canonical_warp_idx_sync() - 4;
        constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/2)/4/NUM_WARPS;
        if (elect_one_sync()) {
            bf16* sK_base = plan.u.s.k.data() + warp_idx*4*64;

            CUTE_NO_UNROLL
            for (int k = 0; k < num_k_blocks; ++k) {
                int4 indices[NUM_LOCAL_ROWS_PER_WARP];
                int max_indices = -1, min_indices = params.s_kv;
                CUTE_UNROLL
                for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
                    indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx);
                    max_indices = max(max_indices, int4_max(indices[local_row]));
                    min_indices = min(min_indices, int4_min(indices[local_row]));
                }
                // Cheap check: true when every index is >= s_kv, or every index is <= -1
                // (a mix of both kinds of invalid index is not detected and is simply loaded).
                bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1;
                bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS;

                // Gather columns [local_col_start*64, local_col_end*64) of this warp's rows into the K buffer
                auto load_part_ki = [&](transac_bar_t &bar, int local_col_start, int local_col_end) {
                    CUTE_UNROLL
                    for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
                        CUTE_UNROLL
                        for (int local_col = local_col_start; local_col < local_col_end; ++local_col)
                            ku::tma_gather4_cta_group_2(
                                &(tma_params.tensor_map_kv),
                                bar,
                                sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64),
                                local_col*64,
                                indices[local_row],
                                (int64_t)TMA::CacheHintSm90::EVICT_LAST
                            );
                    }
                };

                int cur_buf = k%NUM_BUFS;
                // K part0 (smem-Q columns) may be overwritten once the previous QK part0 gemm finished
                if (k > 0) {
                    plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
                }
                if (!should_skip_tma) {
                    load_part_ki(plan.bar_k_part0_ready[cur_buf], 0, D_sQ/64);
                } else {
                    // NOTE: TMA has performance issues when all indices are the same (even if those indices are invalid), so we detect whether all indices in our block are invalid (by inspecting their MIN and MAX, for performance reasons), and skip the copy if all indices are invalid.
                    // NOTE: We can also skip the initial zero-fill procedure (which prevents NaN from appearing in K/V buf if the first TMA copy is skipped) by disabling skipping on the first NUM_BUFS TMAs.
                    // NOTE: We only do this for K to save some checking overhead, since after doing this for K, cases where topk indices are all invalid are faster than the other cases
                    plan.bar_k_part0_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_sQ*sizeof(bf16), 1u);
                }
                // K part1 (tmem-Q columns) may be overwritten once the whole previous QK gemm finished
                if (k > 0) {
                    plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
                }
                if (!should_skip_tma) {
                    load_part_ki(plan.bar_k_part1_ready[cur_buf], D_sQ/64, D_K/64);
                } else {
                    plan.bar_k_part1_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_tQ*sizeof(bf16), 1u);
                }
            }
        }
    } else if (warpgroup_idx == 2) {
        // Producer warps for V
        // Each CTA gathers all B_TOPK rows but only its half of the V columns (offset 256 for CTA 1).
        cutlass::arch::warpgroup_reg_dealloc<96>();
        int warp_idx = cutlass::canonical_warp_idx_sync() - 8;
        constexpr int NUM_WARPS = 4;
        if (elect_one_sync()) {
            // Wait for UTCCP
            plan.bar_prologue_utccp.wait(0);

            bf16* sV_base = plan.u.s.v.data() + warp_idx*4*64;

            CUTE_NO_UNROLL
            for (int k = 0; k < num_k_blocks; ++k) {
                // Gather local rows [local_row_start, local_row_end) of this warp into the V buffer
                auto load_part_vi = [&](transac_bar_t &bar, int local_row_start, int local_row_end) {
                    CUTE_UNROLL
                    for (int local_row = local_row_start; local_row < local_row_end; ++local_row) {
                        int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);
                        CUTE_UNROLL
                        for (int local_col = 0; local_col < (D_V/2)/64; ++local_col)
                            ku::tma_gather4_cta_group_2(
                                &(tma_params.tensor_map_kv),
                                bar,
                                sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),
                                local_col*64 + (cta_idx?256:0),
                                token_idxs,
                                (int64_t)TMA::CacheHintSm90::EVICT_LAST
                            );
                    }
                };

                int cur_buf = k%NUM_BUFS;
                if (k > 0) {
                    plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
                }
                load_part_vi(plan.bar_v_part0_ready[cur_buf], 0, (B_TOPK/2)/4/NUM_WARPS);
                if (k > 0) {
                    plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
                }
                load_part_vi(plan.bar_v_part1_ready[cur_buf], (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS);
            }
        }
    } else {
        cutlass::arch::warpgroup_reg_alloc<168>();
        // MMA warp
        if (cta_idx == 0 && warp_idx == 12 && elect_one_sync()) {
            // S -> T copy for Q
            // The part of Q starting at (B_H/2)*D_sQ is copied from smem into tmem (tmem_cols::q)
            UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc(
                make_tensor(
                    make_smem_ptr(plan.u.q_full.data() + (B_H/2)*D_sQ),
                    tile_to_shape(
                        UMMA::Layout_K_SW128_Atom{},
                        Shape, Int<64>>{}
                    )
                )
            );
            plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16));
            plan.bar_prologue_q.wait(0);
            ku::tcgen05_after_thread_sync();
            CUTE_UNROLL
            for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) {
                // A tile is 64 rows * 64 cols (128B)
                CUTE_UNROLL
                for (int subtile_idx = 0; subtile_idx < 8; ++subtile_idx) {
                    // A subtile is 64 rows * 8 cols (128b)
                    SM100_UTCCP_2x64dp128bitlw0213_2cta::copy(
                        sQ_desc + tile_idx*((B_H/2)*128/16) + subtile_idx*(16/16),  // Remember that 4 LSBs are not included
                        tmem_cols::q + tile_idx*32 + subtile_idx*4
                    );
                }
            }
            ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2);

            // Software pipeline: iteration k issues P(k) = Q K(k)^T (if k < num_k_blocks)
            // and O += S(k-1) V(k-1) (if k > 0), hence num_k_blocks+1 iterations.
            CUTE_NO_UNROLL
            for (int k = 0; k < num_k_blocks+1; ++k) {
                if (k < num_k_blocks) {
                    // Pi = QKi^T
                    int cur_buf = k%NUM_BUFS;
                    Tensor sQl = make_tensor(make_smem_ptr(plan.u.s.sq.data()), SmemLayoutQTiles{});
                    Tensor sKl = make_tensor(make_smem_ptr(plan.u.s.k.data()), SmemLayoutKTiles{});
                    Tensor sKr = make_tensor(make_smem_ptr(plan.u.s.k.data()+64*D_sQ), SmemLayoutKTiles{});

                    // Wait for K (part0)
                    plan.bar_k_part0_ready[cur_buf].arrive_and_expect_tx(B_TOPK*D_sQ*sizeof(bf16));
                    plan.bar_k_part0_ready[cur_buf].wait((k/NUM_BUFS)&1);
                    // tP is shared across iterations: wait until warpgroup 0 has read the previous P
                    if (k > 0) {
                        plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
                    }
                    ku::tcgen05_after_thread_sync();
                    ku::utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true);
                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2);

                    // Wait for K (part1)
                    plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16));
                    plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1);
                    ku::tcgen05_after_thread_sync();
                    ku::utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false);
                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2);
                }
                if (k > 0) {
                    // O += S(i-1)V(i-1)
                    int cur_buf = (k-1)%NUM_BUFS;
                    Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutSTiles<2>{});
                    Tensor sV = make_tensor(make_smem_ptr(plan.u.s.v.data()), SmemLayoutV{});
                    Tensor sS_divided = flat_divide(sS, Tile, _64>{})(_, _, _0{}, _);  // (B_H/2, 64, 2)
                    Tensor sV_divided = flat_divide(sV, Tile, _64>{})(_, _, _0{}, _);  // (D_V/2, 64, 2)

                    // Wait for S(i-1) and O to be scaled
                    plan.bar_so_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);

                    // Wait for V (part0), and issue O += sS @ sV
                    // (last argument is true only for the first SV gemm, k == 1 — presumably the
                    // "zero accumulator" flag; TODO confirm against ku::utcmma_ss)
                    plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
                    plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
                    ku::tcgen05_after_thread_sync();
                    ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1);
                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2);

                    // Wait for V (part1), and issue O += sS @ sV
                    plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
                    plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
                    ku::tcgen05_after_thread_sync();
                    ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false);
                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2);
                }
            }
        } else if (warp_idx == 13) {
            // KV valid loading warp
            // 16 lanes x 8 indices = B_TOPK; each lane packs 8 validity bits into one byte.
            // An index is valid iff 0 <= index < s_kv and its position is below topk_length.
            static_assert(B_TOPK == 128);
            if (lane_idx < 16) {
                CUTE_NO_UNROLL
                for (int k = 0; k < num_k_blocks; ++k) {
                    int cur_buf = k%NUM_BUFS;
                    int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8);
                    auto is_valid = [&](int rel_pos_in_lane, int index) -> char {
                        int abs_pos = k*B_TOPK + lane_idx*8 + rel_pos_in_lane;
                        return index >= 0 && index < params.s_kv && abs_pos < topk_length;
                    };
                    char is_ks_valid_mask = \
                        is_valid(7, indices.a7) << 7 |
                        is_valid(6, indices.a6) << 6 |
                        is_valid(5, indices.a5) << 5 |
                        is_valid(4, indices.a4) << 4 |
                        is_valid(3, indices.a3) << 3 |
                        is_valid(2, indices.a2) << 2 |
                        is_valid(1, indices.a1) << 1 |
                        is_valid(0, indices.a0) << 0;
                    // Parses as ((k/NUM_BUFS)&1)^1: the opposite parity of the consumer-side wait
                    plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);
                    plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask;
                    plan.bar_k_valid_ready[cur_buf].arrive();
                }
            }
        }
    }
#else
    if (cute::thread0()) {
        CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
    }
#endif
}

// Thin __global__ entry point forwarding to the device function above.
// The third __launch_bounds__ argument (2) matches the 2-CTA cluster used at launch.
template
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)
sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) {
    Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);
}

// Host launcher: validates the supported configuration, builds the TMA descriptors for
// Q (2-SM load), O (store; assumes O is contiguous with strides (d_v, 1, h_q*d_v)) and
// KV (2D tensor map of D_QK x s_kv, 64x1 box, 128B swizzle), then launches 2*s_q CTAs
// in clusters of 2 on params.stream.
template
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
    static_assert(D_QK == 576 || D_QK == 512);
    using Kernel = KernelTemplate;
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.topk % Kernel::B_TOPK == 0);   // To save some boundary checks
    KU_ASSERT(params.h_q == Kernel::B_H);  // To save some calculation
    KU_ASSERT(params.d_qk == D_QK);

    auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
    auto tma_Q = cute::make_tma_copy(
        SM100_TMA_2SM_LOAD_NOSPLIT{},
        make_tensor(
            make_gmem_ptr((bf16*)params.q),
            make_layout(
                shape_Q,
                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
            )
        ),
        (typename Kernel::template SmemLayoutQTiles){}
    );

    auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);
    auto tma_O = cute::make_tma_copy(
        SM90_TMA_STORE{},
        make_tensor(
            make_gmem_ptr((bf16*)params.out),
            make_layout(
                shape_O,
                make_stride(params.d_v, _1{}, params.h_q*params.d_v)
            )
        ),
        (typename Kernel::template SmemLayoutOTiles<1>){}
    );

    CUtensorMap tensor_map_kv;
    {
        uint64_t size[2] = {D_QK, (unsigned long)params.s_kv};
        uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};
        uint32_t box_size[2] = {64, 1};
        uint32_t elem_stride[2] = {1, 1};
        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
            &tensor_map_kv,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            2,
            params.kv,
            size,
            stride,
            box_size,
            elem_stride,
            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
        );
        KU_ASSERT(res == CUresult::CUDA_SUCCESS);
    }

    TmaParams<
        decltype(shape_Q), decltype(tma_Q),
        decltype(shape_O), decltype(tma_O)
    > tma_params = {
        shape_Q, tma_Q,
        shape_O, tma_O,
        tensor_map_kv
    };
    auto kernel = &sparse_attn_fwd_kernel;

    constexpr size_t smem_size = sizeof(typename Kernel::SharedMemoryPlan);
    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    // Grid: 2 CTAs per query token (one cluster of 2 per token)
    cutlass::ClusterLaunchParams launch_params = {
        dim3(2*params.s_q, 1, 1),
        dim3(Kernel::NUM_THREADS, 1, 1),
        dim3(2, 1, 1),
        smem_size,
        params.stream
    };
    KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster(
        launch_params, (void*)kernel, params, tma_params
    ));
}

}

================================================
FILE: csrc/sm100/prefill/sparse/fwd/head128/phase1.h
================================================
#pragma once

#include "params.h"

namespace sm100::fwd::head128 {

template
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);

}

================================================
FILE: csrc/sm100/prefill/sparse/fwd/head64/config.h
================================================
#pragma once

#include
#include

#include "defines.h"

namespace sm100::fwd::head64 {

using namespace cute;

// TMA descriptors and global shapes passed to the head64 kernel as a __grid_constant__
template<
    typename Shape_Q_NoPE, typename TMA_Q_NoPE,
    typename Shape_Q_RoPE, typename TMA_Q_RoPE,
    typename Shape_O, typename TMA_O
>
struct TmaParams {
    Shape_Q_NoPE shape_Q_nope; TMA_Q_NoPE tma_Q_nope;
    Shape_Q_RoPE shape_Q_rope; TMA_Q_RoPE tma_Q_rope;
    Shape_O shape_O; TMA_O tma_O;
    CUtensorMap tensor_map_kv_nope;
};

struct float2x2 {
    float2 lo, hi;
};

constexpr int D_Q = 576;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr float MAX_INIT_VAL = -1e30;    // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan

constexpr int B_H = 64;
constexpr int B_TOPK = 64;
constexpr int NUM_BUFS = 3;
constexpr int NUM_THREADS = 128 + 128 + 128;   // 128 scale & exp threads, 128 TMA threads, 128 threads in the third warpgroup (which hosts the 32-thread UTCMMA warp)

// Tensor memory columns
namespace tmem_cols {
    // 0 ~ 256: output
    // 256 ~ 400: Q (256 ~ 384: NoPE part, 384 ~ 400: RoPE part)
    // 400 ~ 464: P
    constexpr int O = 0;
    constexpr int Q = 256;
    constexpr int Q_RoPE = 256 + 128;
    constexpr int P = 400;
}

using SmemLayoutQNoPE = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_SW128_Atom{},
    Shape, Int>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));

using SmemLayoutQRoPE = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_SW64_Atom{},
    Shape, Int>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));

template
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_SW128_Atom{},
    Shape, Int<64*NUM_TILES>>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));

using SmemLayoutO = SmemLayoutOTiles<8>;

template
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_SW128_Atom{},
    Shape, Int<64*NUM_TILES>>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));

using SmemLayoutKNoPE = SmemLayoutKTiles<8>;

using SmemLayoutV = decltype(coalesce(
    composition(
        SmemLayoutKNoPE{},
        Layout, Int>, Stride, _1>>{}
    )
, Shape<_1, _1>{}));

using SmemLayoutKRoPE = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_SW64_Atom{},
    Shape, Int<64>>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));

using SmemLayoutKNoPE_TiledMMA = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_SW128_Atom{},
    Shape, Int>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));   // Re-view K-NoPE as B_TOPK*2 x D_V/2 for dual gemm

using SmemLayoutKRoPE_TiledMMA = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_SW64_Atom{},
    Shape, Int<64/2>>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));

using SmemLayoutS = decltype(coalesce(tile_to_shape(
    UMMA::Layout_K_INTER_Atom{},
    Shape, Int>{},
    Step<_1, _2>{}
), Shape<_1, _1>{}));

// Dynamic shared-memory layout of the head64 kernel.
// `u` overlays the Q staging area (q_full), the K buffers (k) and the O epilogue buffer (o);
// `s_q_rope` overlays the S tile with the Q RoPE staging area.
struct SharedMemoryPlan {
    union {
        struct {
            array_aligned> _k_rope_pad;
            array_aligned> _k_pad[2];  // So that q_nope covers k[2]
            array_aligned> q_nope;
        } q_full;
        struct {
            array_aligned> k_rope;
            array_aligned> k_nope[NUM_BUFS];
        } k;
        array_aligned> o;
    } u;
    float p_exchange_buf[4][32 * (B_TOPK/2)];
    union {
        bf16 s[B_H*B_TOPK];
        array_aligned> q_rope;
    } s_q_rope;
    char is_k_valid[NUM_BUFS][B_TOPK/8];
    transac_bar_t bar_prologue_q_nope, bar_prologue_q_rope, bar_prologue_utccp_nope, bar_prologue_utccp_rope;
    transac_bar_t bar_qk_nope_done[NUM_BUFS], bar_qk_rope_done;  // Pi = QKi^T (the nope part) done
    transac_bar_t bar_sv_done[NUM_BUFS];  // O += SiVi done (i.e. O, Si and Vi are free)
    transac_bar_t bar_kv_nope_ready[NUM_BUFS][2], bar_kv_rope_ready;
    transac_bar_t bar_p_free;
    transac_bar_t bar_so_ready;  // S and O are ready
    transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];
    array_aligned tmem_start_addr;
    float rowwise_max_buf[128], rowwise_li_buf[128];
};

using TiledMMA_P = decltype(make_tiled_mma(
    SM100_MMA_F16BF16_WS_TS_NOELECT{}   // Here we use N = 128 = 2*B_TOPK since we're going to use implicit dual gemm:
));

using TiledMMA_O = decltype(make_tiled_mma(
    SM100_MMA_F16BF16_WS_SS_NOELECT{}
));

enum NamedBarriers : int {
    wg0_sync = 0,
    wg0_warp02_sync = 1,
    wg0_warp13_sync = 2,
    pepi_sync = 3,
};

}

================================================
FILE: csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu
================================================
#include "../phase1.h"
#include "../phase1.cuh"

namespace sm100::fwd::head64 {

template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params);

}

================================================
FILE: csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu
================================================
#include "../phase1.h"
#include "../phase1.cuh"

namespace sm100::fwd::head64 {

template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params);

}

================================================
FILE: csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh
================================================
#pragma once

#include "phase1.h"

#include
#include
#include
#include
#include
#include
#include
"params.h" #include "utils.h" #include "sm100/helpers.h" #include "sm100/prefill/sparse/common_subroutine.h" #include "config.h" namespace sm100::fwd::head64 { using namespace cute; /* Pipeline Overview: | Copy | MMA | Scale & Exp | KV0 KV1 KV2 P0 = QK0^T S0 = exp(P0) scale(O) w.r.t P0 P1 = QK1^T S1 = exp(P1) O += S0V0 KV3 scale(O) w.r.t P1 P2 = QK2^T S2 = exp(P2) O += S1V1 KV4 scale(O) w.r.t P2 P3 = QK3^T S3 = exp(P3) O += S2V2 KV5 scale(O) w.r.t P3 ... O += S(n-3)V(n-3) scale(O) w.r.t P(n-2) P(n-1) = QK(n-1)^T S(n-1) = exp(P(n-1)) O += S(n-2)V(n-2) scale(O) w.r.t P(n-1) O += S(n-1)V(n-1) */ using FwdMode = SparseAttnFwdMode; template __global__ void __launch_bounds__(NUM_THREADS, 1, 1) sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) // Grid shape: [s_q, 1, 1] const int s_q_idx = blockIdx.x; const int warp_idx = cutlass::canonical_warp_idx_sync(); const int lane_idx = threadIdx.x % 32; const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); const int idx_in_warpgroup = threadIdx.x % 128; const int topk_length = params.topk_length != nullptr ? 
__ldg(params.topk_length + s_q_idx) : params.topk; const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1 // Define shared tensors extern __shared__ char wksp_buf[]; SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk] // Allocate tmem tensors TiledMMA tiled_mma_P = TiledMMA_P{}; TiledMMA tiled_mma_O = TiledMMA_O{}; // NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm) Tensor tP = partition_fragment_C(tiled_mma_P, Shape, _128>{}); Tensor tQ_nope_part0 = tiled_mma_P.get_slice(_0{}).make_fragment_A( partition_shape_A(tiled_mma_P, Shape, Int<(D_V/2)/2>>{}) ); Tensor tQ_nope_part1 = tiled_mma_P.get_slice(_0{}).make_fragment_A( partition_shape_A(tiled_mma_P, Shape, Int<(D_V/2)/2>>{}) ); Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A( partition_shape_A(tiled_mma_P, Shape, Int<64/2>>{}) ); Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); tP.data().get() = tmem_cols::P; tQ_nope_part0.data().get() = tmem_cols::Q; tQ_nope_part1.data().get() = tmem_cols::Q + 64; tQ_rope.data().get() = tmem_cols::Q_RoPE; tO.data().get() = tmem_cols::O; if (warp_idx == 0) { if (elect_one_sync()) { // Copy Q if constexpr (HAVE_ROPE) { cute::prefetch_tma_descriptor(tma_params.tma_Q_rope.get_tma_descriptor()); } cute::prefetch_tma_descriptor(tma_params.tma_Q_nope.get_tma_descriptor()); plan.bar_prologue_q_nope.init(1); plan.bar_prologue_q_rope.init(1); fence_barrier_init(); if constexpr (HAVE_ROPE) { Tensor gQ_rope = tma_params.tma_Q_rope.get_tma_tensor(tma_params.shape_Q_rope)(_, _, s_q_idx); Tensor sQ_rope = make_tensor(make_smem_ptr(plan.s_q_rope.q_rope.data()), SmemLayoutQRoPE{}); ku::launch_tma_copy(tma_params.tma_Q_rope, gQ_rope, sQ_rope, plan.bar_prologue_q_rope, TMA::CacheHintSm90::EVICT_FIRST); } Tensor gQ_nope = 
tma_params.tma_Q_nope.get_tma_tensor(tma_params.shape_Q_nope)(_, _, s_q_idx); Tensor sQ_nope = make_tensor(make_smem_ptr(plan.u.q_full.q_nope.data()), SmemLayoutQNoPE{}); ku::launch_tma_copy(tma_params.tma_Q_nope, gQ_nope, sQ_nope, plan.bar_prologue_q_nope, TMA::CacheHintSm90::EVICT_FIRST); cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv_nope)); // Initialize other barriers plan.bar_prologue_utccp_rope.init(1); plan.bar_prologue_utccp_nope.init(1); CUTE_UNROLL for (int i = 0; i < NUM_BUFS; ++i) { plan.bar_qk_nope_done[i].init(1); plan.bar_sv_done[i].init(1); plan.bar_kv_nope_ready[i][0].init(1); plan.bar_kv_nope_ready[i][1].init(1); plan.bar_k_valid_ready[i].init(B_TOPK/8); plan.bar_k_valid_free[i].init(128); } plan.bar_p_free.init(128); plan.bar_so_ready.init(128); plan.bar_qk_rope_done.init(1); plan.bar_kv_rope_ready.init(64); fence_barrier_init(); } // Initialize TMEM cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data()); TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); cute::TMEM::Allocator1Sm().release_allocation_lock(); } __syncthreads(); if (warpgroup_idx == 0) { // Scale & Exp warps // The following three numbers are // - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V) // - li: sumexp, i.e. li := sum(exp(Pi*scale - mi)) // - real_mi: real max logits, i.e. real_mi := max(Pi*scale) // where Pi is the i-th row of P, P := QK^T // mi and real_mi are always consistent within the two threads that // controls one row (i.e. thread 0+64, 1+65, 2+66, ...) 
after every update float mi = MAX_INIT_VAL; float li = 0.0f; float real_mi = -CUDART_INF_F; bf16* sS_base = plan.s_q_rope.s + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2); static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2; CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks; ++k) { // Wait for P NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1)); plan.bar_qk_nope_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1); plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1); // Put the barrier wait here for more code reordering space ku::tcgen05_after_thread_sync(); // Load P float p[NUM_ELEMS_PER_THREAD]; retrieve_mask_and_reduce_p< NUM_ELEMS_PER_THREAD, tmem_cols::P, NamedBarriers::wg0_warp02_sync, NamedBarriers::wg0_warp13_sync, false >( plan.is_k_valid[k%NUM_BUFS], warp_idx, lane_idx, [&]() {plan.bar_p_free.arrive();}, plan.p_exchange_buf, p ); plan.bar_k_valid_free[k%NUM_BUFS].arrive(); // Get rowwise max of Pi float cur_pi_max = get_max(p); cur_pi_max *= params.sm_scale_div_log2; plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]); real_mi = max(real_mi, cur_pi_max); bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); // By this point: // - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...) // - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. 
among threads 0~31+64~95; and is identical among threads 32~63+96~127) // Calc scale factor, and scale li float new_max, scale_for_old; if (!should_scale_o) { // Don't scale O scale_for_old = 1.0f; new_max = mi; } else { new_max = max(cur_pi_max, mi); scale_for_old = exp2f(mi - new_max); } mi = new_max; // mi is still identical within each row // Calculate S nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2]; float cur_sum = get_s_from_p(s, p, params.sm_scale_div_log2, new_max); li = fma(li, scale_for_old, cur_sum); // Wait for last SV gemm, write S if (k > 0) { plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); } CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; i += 1) { *(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4); } // Scale O if (k > 0 && should_scale_o) { // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before ku::tcgen05_after_thread_sync(); rescale_O(scale_for_old); ku::tcgen05_before_thread_sync(); } fence_view_async_shared(); plan.bar_so_ready.arrive(); } // Epilogue if (real_mi == -CUDART_INF_F) { // real_mi == -CUDART_INF_F <=> No valid TopK indices // We set li to 0 to fit the definition that li := exp(x[i] - mi) li = 0.0f; mi = -CUDART_INF_F; } // Exchange li plan.rowwise_li_buf[idx_in_warpgroup] = li; NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); li += plan.rowwise_li_buf[idx_in_warpgroup^64]; // Store mi and li if (idx_in_warpgroup < 64) { int global_index = s_q_idx*params.h_q + idx_in_warpgroup; float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li)); cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse; params.max_logits[global_index] = real_mi*CUDART_LN2_F; params.lse[global_index] = cur_lse; } // Wait for the last GEMM plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1); ku::tcgen05_after_thread_sync(); // Fetch dO if necessary // Store O float attn_sink = params.attn_sink == nullptr ? 
-CUDART_INF_F : __ldg(params.attn_sink + (idx_in_warpgroup%64))*CUDART_L2E_F; float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi)); Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{}); constexpr int B_EPI = 64; Tensor tma_gO = flat_divide( tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx), Shape, Int>{} )(_, _, _0{}, _); Tensor sO_divided = flat_divide( sO, Shape, Int>{} )(_, _, _0{}, _); auto thr_tma = tma_params.tma_O.get_slice(_0{}); float2 o[B_EPI/2]; bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld if (!have_valid_indices) { // If there are no valid indices, we set o[i] to 0 and don't load from TMEM CUTE_UNROLL for (int i = 0; i < B_EPI/2; ++i) o[i].x = o[i].y = 0.0f; output_scale = 1.0f; } float2 output_scale_float2 = make_float2(output_scale, output_scale); bf16* sO_addrs[8]; CUTE_UNROLL for (int i = 0; i < B_EPI/8; ++i) { sO_addrs[i] = &sO(idx_in_warpgroup%64, i*8); } CUTE_UNROLL for (int c = 0; c < 2; ++c) { // Each tile: 64 x 256 CUTE_UNROLL for (int k = 0; k < (D_V/4)/B_EPI; ++k) { // Load O from tO if (have_valid_indices) { ku::tmem_ld_32dp32bNx(tmem_cols::O + c*128 + k*B_EPI, o); cutlass::arch::fence_view_async_tmem_load(); } // Convert and store CUTE_UNROLL for (int i = 0; i < B_EPI/8; ++i) { nv_bfloat162 o_bf16[4]; CUTE_UNROLL for (int j = 0; j < 4; ++j) { o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2); o_bf16[j] = __float22bfloat162_rn(o[i*4+j]); } *(uint128_t*)(sO_addrs[i] + (c*(D_V/2) + (idx_in_warpgroup/64)*(D_V/4) + k*B_EPI)*64) = *(uint128_t*)(o_bf16); } // Sync fence_view_async_shared(); NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); if (warp_idx == 0 && elect_one_sync()) { int epi_chunk_idx = c*(D_V/2/B_EPI) + k; cute::copy( tma_params.tma_O, thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)), thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx)) ); } if (warp_idx 
== 1 && elect_one_sync()) { int epi_chunk_idx = c*(D_V/2/B_EPI) + (D_V/B_EPI/4) + k; cute::copy( tma_params.tma_O, thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)), thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx)) ); } } } if (warp_idx == 0) { cute::TMEM::Allocator1Sm().free(0, 512); } } else if (warpgroup_idx == 1) { // Producer warp for KV int warp_idx = cutlass::canonical_warp_idx_sync() - 4; constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/4)/NUM_WARPS; if (elect_one_sync()) { CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks; ++k) { int4 indices[NUM_LOCAL_ROWS_PER_WARP]; int max_indices = -1, min_indices = params.s_kv; CUTE_UNROLL for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx); max_indices = max(max_indices, int4_max(indices[local_row])); min_indices = min(min_indices, int4_min(indices[local_row])); } bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1; bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS; if (k == 2) { plan.bar_prologue_utccp_nope.wait(0); // Since q_nope coincidences with k[2] } // Copy NoPE int cur_buf = k%NUM_BUFS; plan.bar_sv_done[cur_buf].wait((k/NUM_BUFS)&1^1); bf16* sK_nope_base = plan.u.k.k_nope[cur_buf].data() + warp_idx*4*64; auto load_kv_nope_part = [&](int part_idx) { CUTE_UNROLL for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { CUTE_UNROLL for (int local_col = part_idx*(D_V/2/64); local_col < (part_idx+1)*(D_V/2/64); ++local_col) { ku::tma_gather4( &(tma_params.tensor_map_kv_nope), plan.bar_kv_nope_ready[cur_buf][part_idx], sK_nope_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64), local_col*64, indices[local_row], (int64_t)TMA::CacheHintSm90::EVICT_LAST ); } } }; if (!should_skip_tma) { load_kv_nope_part(0); load_kv_nope_part(1); } else { // NOTE See head128/phase1.cuh for this TMA skipping technique CUTE_UNROLL for (int part_idx = 
0; part_idx < 2; ++part_idx) plan.bar_kv_nope_ready[cur_buf][part_idx].complete_transaction(NUM_LOCAL_ROWS_PER_WARP*4*D_V/2*sizeof(bf16)); } } } } else { // MMA warp if (warp_idx == 8 && elect_one_sync()) { // S -> T copy for Q UMMA::SmemDescriptor sQ_nope_desc = UMMA::make_umma_desc( make_tensor( make_smem_ptr(plan.u.q_full.q_nope.data()), tile_to_shape( UMMA::Layout_K_SW128_Atom{}, Shape, Int<64>>{} // We use this shape for dual gemm (TODO Link) ) ) ); UMMA::SmemDescriptor sQ_rope_desc = UMMA::make_umma_desc( make_tensor( make_smem_ptr(plan.s_q_rope.q_rope.data()), tile_to_shape( UMMA::Layout_K_SW64_Atom{}, Shape, Int<32>>{} ) ) ); if constexpr (HAVE_ROPE) { // Copy the RoPE tile: 128 rows * 32 cols (64B) (in UTCCP's view), or 64 rows * 64 cols (in our view) plan.bar_prologue_q_rope.arrive_and_expect_tx(B_H*(D_Q-D_V)*sizeof(bf16)); plan.bar_prologue_q_rope.wait(0); ku::tcgen05_after_thread_sync(); CUTE_UNROLL for (int subtile_idx = 0; subtile_idx < 2; ++subtile_idx) { // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view) SM100_UTCCP_128dp256bit_1cta::copy( sQ_rope_desc + (subtile_idx*32) / 16, tmem_cols::Q_RoPE + subtile_idx*8 ); } ku::umma_arrive_noelect(plan.bar_prologue_utccp_rope); } plan.bar_prologue_q_nope.arrive_and_expect_tx(B_H*D_V*sizeof(bf16)); plan.bar_prologue_q_nope.wait(0); ku::tcgen05_after_thread_sync(); CUTE_UNROLL for (int tile_idx = 0; tile_idx < D_V/64/2; ++tile_idx) { // A tile is 128 rows * 64 cols (128B) (in UTCCP's view), or 64 rows * 128 cols (in our view) CUTE_UNROLL for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) { // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view) SM100_UTCCP_128dp256bit_1cta::copy( sQ_nope_desc + (tile_idx*(B_H*128*2) + subtile_idx*32) / 16, // Remember that 4 LSBs are not included tmem_cols::Q + tile_idx*32 + subtile_idx*8 ); } } ku::umma_arrive_noelect(plan.bar_prologue_utccp_nope); if constexpr 
(HAVE_ROPE) { plan.bar_prologue_utccp_rope.wait(0); } CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks+1; ++k) { if (k < num_k_blocks) { // Pi = QKi^T int cur_buf = k%NUM_BUFS; Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutKNoPE_TiledMMA{}); Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE_TiledMMA{}); plan.bar_p_free.wait(k&1^1); ku::tcgen05_after_thread_sync(); // Wait for K (RoPE) // P = Q(rope) @ K(rope)^T if constexpr (HAVE_ROPE) { plan.bar_kv_rope_ready.wait(k&1); ku::tcgen05_after_thread_sync(); ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true); ku::umma_arrive_noelect(plan.bar_qk_rope_done); } // Wait for K (NoPE) if (k == 0) { plan.bar_prologue_utccp_nope.wait(0); } Tensor sK_nope_divided = flat_divide(sK_nope, Tile, Int>{})(_, _, _0{}, _); CUTE_UNROLL for (int kv_nope_part_idx = 0; kv_nope_part_idx < 2; ++kv_nope_part_idx) { plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].arrive_and_expect_tx(B_TOPK*D_V/2*sizeof(bf16)); plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].wait((k/NUM_BUFS)&1); ku::tcgen05_after_thread_sync(); // P += Q(nope) @ K(nope)^T bool clear_accum = (!HAVE_ROPE) && kv_nope_part_idx == 0; ku::utcmma_ts(tiled_mma_P, kv_nope_part_idx ? 
tQ_nope_part1 : tQ_nope_part0, sK_nope_divided(_, _, kv_nope_part_idx), tP, clear_accum); } ku::umma_arrive_noelect(plan.bar_qk_nope_done[cur_buf]); } if (k > 0) { // O += S(i-1)V(i-1) int cur_buf = (k-1)%NUM_BUFS; Tensor sS = make_tensor(make_smem_ptr(plan.s_q_rope.s), SmemLayoutS{}); Tensor sV = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutV{}); // Wait for S(i-1) and O to be scaled plan.bar_so_ready.wait((k-1)&1); ku::tcgen05_after_thread_sync(); // O += sS @ sV ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == 1); ku::umma_arrive_noelect(plan.bar_sv_done[cur_buf]); } } } else if (warp_idx == 9) { // KV valid loading warp if (lane_idx < B_TOPK/8) { CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks; ++k) { char k_validness_mask = load_indices_and_generate_mask( lane_idx, gIndices + k*B_TOPK, params.s_kv, k*B_TOPK, topk_length ); int cur_buf = k%NUM_BUFS; plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1); plan.is_k_valid[cur_buf][lane_idx] = k_validness_mask; plan.bar_k_valid_ready[cur_buf].arrive(); } } } else if (warp_idx == 10 || warp_idx == 11) { if constexpr (HAVE_ROPE) { int thread_idx = threadIdx.x - 10*32; constexpr int GROUP_SIZE = 8, NUM_GROUPS = 64/GROUP_SIZE, ROWS_PER_THREAD = B_TOPK/NUM_GROUPS; int group_idx = thread_idx / GROUP_SIZE, idx_in_group = thread_idx % GROUP_SIZE; Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE{}); bf16* sK_rope_base = &sK_rope(group_idx, idx_in_group*8); CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks; ++k) { int indices[ROWS_PER_THREAD]; CUTE_UNROLL for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row) indices[local_row] = __ldg(gIndices + k*B_TOPK + group_idx + local_row*NUM_GROUPS); plan.bar_qk_rope_done.wait(k&1^1); CUTE_UNROLL for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row) { int index = indices[local_row]; ku::cp_async_cacheglobal( params.kv + (int64_t)index*params.stride_kv_s_kv + 512 + idx_in_group*8, sK_rope_base + 
local_row*NUM_GROUPS*32,
                            index >= 0 && index < params.s_kv
                        );
                        // NOTE Using cp.async instead of TMA is faster here
                        // NOTE Here we only consider the range of `index` instead of also checking against topk_length, as it's noted that under this scenario (i.e. there exists a valid index among indices[topk_length: ] that points to a token who has NaN inside)
                    }
                    cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)&(plan.bar_kv_rope_ready));
                }
            }
        }
    }
#else
    if (cute::thread0()) {
        CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
    }
#endif
}

/*
 * Host-side launcher for the phase-1 sparse-attention forward kernel of this namespace.
 *
 *  1. Checks that the runtime problem matches the compile-time configuration
 *     (single KV head, topk a multiple of B_TOPK, h_q == B_H, d_qk == D_QK, d_v == D_V).
 *  2. Builds TMA descriptors for Q (NoPE and RoPE parts separately), for O (store),
 *     and a raw CUtensorMap over the NoPE part of the KV tensor.
 *  3. Raises the kernel's dynamic shared-memory limit to sizeof(SharedMemoryPlan),
 *     launches the kernel and checks the launch.
 */
template void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.topk % B_TOPK == 0);   // Lets the kernel skip some boundary checks
    KU_ASSERT(params.h_q == B_H);   // To save some calculation
    KU_ASSERT(params.d_qk == D_QK);
    // The kernel's epilogue always produces D_V output columns per head, while the O
    // descriptor below is sized by params.d_v; reject mismatches instead of silently
    // clipping (or leaving unwritten) part of the output.
    KU_ASSERT(params.d_v == D_V);
    static_assert(D_QK == 576 || D_QK == 512);

    // Q is addressed as [h_q, D_V, s_q] with a contiguous feature dimension;
    // this descriptor covers the first D_V (NoPE) columns.
    auto shape_Q_nope = make_shape(params.h_q, D_V, params.s_q);
    auto tma_Q_nope = cute::make_tma_copy(
        SM90_TMA_LOAD{},
        make_tensor(
            make_gmem_ptr((bf16*)params.q),
            make_layout(
                shape_Q_nope,
                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
            )
        ),
        SmemLayoutQNoPE{}
    );

    // The remaining D_Q-D_V (RoPE) columns of Q start at column offset D_V.
    auto shape_Q_rope = make_shape(params.h_q, D_Q-D_V, params.s_q);
    auto tma_Q_rope = cute::make_tma_copy(
        SM90_TMA_LOAD{},
        make_tensor(
            make_gmem_ptr((bf16*)params.q + D_V),
            make_layout(
                shape_Q_rope,
                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
            )
        ),
        SmemLayoutQRoPE{}
    );

    // O is addressed as a dense [s_q, h_q, d_v] tensor: its strides are derived from
    // h_q and d_v rather than read from params.
    auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);
    auto tma_O = cute::make_tma_copy(
        SM90_TMA_STORE{},
        make_tensor(
            make_gmem_ptr((bf16*)params.out),
            make_layout(
                shape_O,
                make_stride(params.d_v, _1{}, params.h_q*params.d_v)
            )
        ),
        SmemLayoutOTiles<1>{}
    );

    // Tensor map consumed by the kernel's `ku::tma_gather4` loads of the KV NoPE part:
    // a 2D view of D_V columns x s_kv rows (row stride = stride_kv_s_kv elements),
    // fetched in boxes of 64 columns x 1 row (64 bf16 = 128 B, matching SWIZZLE_128B).
    // The RoPE columns (starting at column 512) are fetched by the kernel with cp.async instead.
    CUtensorMap tensor_map_kv_nope;
    {
        uint64_t size[2] = {D_V, (uint64_t)params.s_kv};
        uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};
        uint32_t box_size[2] = {64, 1};
        uint32_t elem_stride[2] = {1, 1};
        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
            &tensor_map_kv_nope,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            2,
            params.kv,
            size,
            stride,
            box_size,
            elem_stride,
            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
        );
        KU_ASSERT(res == CUresult::CUDA_SUCCESS);
    }

    TmaParams<
        decltype(shape_Q_nope), decltype(tma_Q_nope),
        decltype(shape_Q_rope), decltype(tma_Q_rope),
        decltype(shape_O), decltype(tma_O)
    > tma_params = {
        shape_Q_nope, tma_Q_nope,
        shape_Q_rope, tma_Q_rope,
        shape_O, tma_O,
        tensor_map_kv_nope
    };
    auto kernel = &sparse_attn_fwd_kernel;

    // Allow the kernel to use sizeof(SharedMemoryPlan) bytes of dynamic shared memory.
    constexpr size_t smem_size = sizeof(SharedMemoryPlan);
    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    kernel<<>>(params, tma_params);
    KU_CHECK_KERNEL_LAUNCH();
}

}
================================================ FILE: csrc/sm100/prefill/sparse/fwd/head64/phase1.h ================================================
#pragma once

#include "params.h"

namespace sm100::fwd::head64 {

// Host-side launcher for the phase-1 sparse-attention forward kernel; it asserts the
// shape constraints on `params` before building TMA descriptors and launching.
template void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);

}
================================================ FILE: csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h ================================================
#pragma once

#include "phase1.h"

#include
#include
#include
#include

#include "defines.h"
#include "params.h"

namespace sm100::fwd_for_small_topk::head128 {

using namespace cute;

/*
 * Compile-time configuration of the small-top-k, 128-head sparse attention kernel.
 *
 * The kernel runs on a 2-CTA cluster; each CTA handles H_Q/2 query heads. This struct
 * gathers the TMA descriptor bundle, tile sizes, the tensor-memory column plan, the
 * shared-memory plan and the MMA types. IS_DECODE selects between the prefill and the
 * decode variants (different TMA descriptors and buffer counts).
 */
template
struct KernelTemplate {
    using ArgT = SparseFwdArgT;
    static constexpr bool IS_DECODE = is_decode_v;
    static constexpr bool IS_PREFILL = !IS_DECODE;

    using fp8_e4m3 = cutlass::float_e4m3_t;
    using fp8_e8m0 = __nv_fp8_e8m0;

    struct TmaParamsForPrefill {
        CUtensorMap tensor_map_q;
        CUtensorMap tensor_map_kv;
        CUtensorMap tensor_map_o;
    };

    struct TmaParamsForDecode {
        CUtensorMap tensor_map_q;
        CUtensorMap tensor_map_o;
        CUtensorMap tensor_map_o_accum;
        CUtensorMap tensor_map_kv_nope;
        CUtensorMap tensor_map_kv_rope;
        CUtensorMap tensor_map_extra_kv_nope;   // Only available if extra_kv is enabled
        CUtensorMap tensor_map_extra_kv_rope;
    };

    using TmaParams = std::conditional_t<
        IS_DECODE,
        TmaParamsForDecode,
        TmaParamsForPrefill
    >;

    static_assert(D_QK == 512);
    static constexpr int D_Q = D_QK;
    static constexpr int D_K = D_QK;
    static constexpr int D_V = 512;
    static constexpr float MAX_INIT_VAL = -1e30;    // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan

    static constexpr int H_Q = 128;     // For 2 CTAs (each CTA handles H_Q/2 heads)
    static constexpr int B_TOPK = 64;   // For 2 CTAs
    static constexpr int NUM_THREADS = 128*4;
    // Arrival count of bar_clc_empty (every worker thread arrives once per CLC round)
    static constexpr int NUM_WORKER_THREADS = IS_PREFILL ? (128 + 4 + (B_TOPK/8) + 1 + 128)*2 + 1 : (128 + 128 + 1 + 32 + 2 + 128)*2;

    // For non-decode mode, we have 4 (half-)KV buffers
    // For decode mode, we have 3 (half-)KV buffers with two raw KV buffers
    static constexpr int NUM_K_BUFS = IS_DECODE ? 3 : 4;
    static constexpr int NUM_RAW_K_BUFS = IS_DECODE ? 2 : 0;
    static constexpr int NUM_INDEX_BUFS = IS_DECODE ? 4 : 4;

    static constexpr int D_NOPE = 448;
    static constexpr int D_ROPE = 64;
    static constexpr int TMA_K_STRIDE_FOR_DECODING = D_NOPE + 2*D_ROPE;
    static constexpr int NUM_SCALES_EACH_TOKEN = 8; // 7 scales + 1 padding

    static constexpr int B_EPI = 64;    // Epilogue block size for normal case (i.e. prefill or non-splitkv decoding)
    static constexpr int B_EPI_SPLITKV = 32;    // Epilogue block size for splitkv decoding
    static constexpr int NUM_EPI_SPLITKV_BUFS = 4;  // The number of epilogue buffers for splitkv decoding
    // The split-KV epilogue stages float32 O tiles in the Q buffer, viewed as
    // NUM_EPI_SPLITKV_BUFS buffers of (H_Q/2) x (B_EPI_SPLITKV*2) floats; make sure they fit.
    static_assert((H_Q/2)*D_Q*sizeof(bf16) >= NUM_EPI_SPLITKV_BUFS*(H_Q/2)*(B_EPI_SPLITKV*2)*sizeof(float));

    // Tensor memory columns
    struct tmem_cols {
        // 0 ~ 256: Output accumulator
        // 256 ~ 384: Q
        // 384 ~ 448: P
        static constexpr int O = 0;
        static constexpr int Q = 256;
        static constexpr int P = 384;
    };

    struct SharedMemoryPlan {
        array_aligned Q;    // Will be output for epilogue
        array_aligned K[NUM_K_BUFS];
        array_aligned K_raw[NUM_RAW_K_BUFS];    // Raw FP8 KV; empty unless decoding (NUM_RAW_K_BUFS == 0)
        array_aligned S;
        float P_exchange[4][(H_Q/2/2)*(B_TOPK/2)];
        float rowwise_max_buf[128], rowwise_li_buf[128];
        CUTE_ALIGNAS(16) char is_k_valid[NUM_INDEX_BUFS][B_TOPK/8];
        CUTE_ALIGNAS(16) int tma_coord[NUM_INDEX_BUFS][B_TOPK];
        CUTE_ALIGNAS(16) fp8_e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN/2];
        transac_bar_t bar_sQ_full, bar_tQ_empty, bar_tQ_full;
        transac_bar_t bar_tOut_full, bar_tOut_empty;
        transac_bar_t bar_KV_full[NUM_K_BUFS], bar_KV_empty[NUM_K_BUFS];
        transac_bar_t bar_P_empty;
        transac_bar_t bar_QK_done, bar_SV_done;
        transac_bar_t bar_S_O_full;
        transac_bar_t bar_li_full, bar_li_empty;
        // Used by the CLC-driven job loop, i.e. in every mode except DecodeWithSplitKV
        transac_bar_t bar_clc_full, bar_clc_empty;
        // Decode-only (raw FP8 KV buffers)
        transac_bar_t bar_raw_KV_full[NUM_RAW_K_BUFS], bar_raw_KV_empty[NUM_RAW_K_BUFS];
        // Initialized and used in both prefill and decode modes
        transac_bar_t bar_valid_coord_scales_full[NUM_INDEX_BUFS], bar_valid_coord_scales_empty[NUM_INDEX_BUFS];
        ku::CLCResponseObj clc_response_obj;
        array_aligned tmem_start_addr;
    };

    using TiledMMA_P = decltype(make_tiled_mma(
        SM100_MMA_F16BF16_2x1SM_TS_NOELECT{}
    ));     // *2 for dual gemm

    using TiledMMA_O = decltype(make_tiled_mma(
        SM100_MMA_F16BF16_2x1SM_SS_NOELECT{},
        Layout>{},
        Tile, Layout, Stride<_1, _256, _128>>, _16>{}   // We use this permutation layout to let CTA0 take V[:, 0:256] and CTA1 take V[:, 256:512]
    ));

    // IDs for NamedBarrier::arrive_and_wait
    struct barrier_ids {
        static constexpr int WG0_SYNC = 0;
        static constexpr int WG2_SYNC = 1;
        static constexpr int WG2_WARP02_SYNC = 2;
        static constexpr int WG2_WARP13_SYNC = 3;
    };

    static __device__ void sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params);

    static void run(const ArgT& params);
};

}
================================================ FILE: csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu ================================================
#include "../phase1.h"
#include "../phase1.cuh"

namespace sm100::fwd_for_small_topk::head128 {

// Explicit instantiation of the launcher for decoding (SparseAttnDecodeParams)
template void run_fwd_for_small_topk_phase1_kernel(const SparseAttnDecodeParams& params);

}
================================================ FILE: csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu ================================================
#include "../phase1.h"
#include "../phase1.cuh"

namespace sm100::fwd_for_small_topk::head128 {

// Explicit instantiation of the launcher for prefilling (SparseAttnFwdParams)
template void run_fwd_for_small_topk_phase1_kernel(const SparseAttnFwdParams& params);

}
================================================ FILE: csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh ================================================
#pragma once

#include "phase1.h"

#include
#include
#include
#include
#include

#include "params.h"
#include "utils.h"
#include "sm100/prefill/sparse/common_subroutine.h"
#include "sm100/helpers.h"
#include "config.h"

namespace sm100::fwd_for_small_topk::head128 {

using namespace cute;

using FwdMode = SparseAttnFwdMode;

template __device__ void KernelTemplate::sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params) {
#ifdef KERUTILS_ENABLE_SM100A
    // Grid shape: [2*s_q, 1, 1] for prefilling, [2*s_q, num_sm_parts, 1] for decoding
    // Cluster shape: [2, 1, 1]
    const int warp_idx = cutlass::canonical_warp_idx_sync();
    const int lane_idx = threadIdx.x % 32;
    const int warpgroup_idx = cutlass::canonical_warp_group_idx();
    const int
idx_in_warpgroup = threadIdx.x % 128; const int cta_idx = block_id_in_cluster().x; extern __shared__ char wksp_buf[]; SharedMemoryPlan &smem = *reinterpret_cast(wksp_buf); if (warp_idx == 0 && elect_one_sync()) { cute::prefetch_tma_descriptor(&tma_params.tensor_map_q); cute::prefetch_tma_descriptor(&tma_params.tensor_map_o); if constexpr (IS_DECODE) { cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope); cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope); } else { cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv); } } else if (warp_idx == 1 && elect_one_sync()) { smem.bar_sQ_full.init(1); smem.bar_tQ_empty.init(1); smem.bar_tQ_full.init(1); smem.bar_tOut_full.init(1); smem.bar_tOut_empty.init(256); smem.bar_P_empty.init(256); smem.bar_QK_done.init(1); smem.bar_SV_done.init(1); smem.bar_S_O_full.init(256); smem.bar_li_full.init(H_Q/2); smem.bar_li_empty.init(128); if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) { smem.bar_clc_full.init(1); smem.bar_clc_empty.init(NUM_WORKER_THREADS); } fence_barrier_init(); } else if (warp_idx == 2) { cute::TMEM::Allocator2Sm().allocate(512, smem.tmem_start_addr.data()); KU_TRAP_ONLY_DEVICE_ASSERT(smem.tmem_start_addr.data()[0] == 0); cute::TMEM::Allocator2Sm().release_allocation_lock(); } else if (warp_idx == 3 && elect_one_sync()) { CUTE_UNROLL for (int i = 0; i < NUM_K_BUFS; ++i) { smem.bar_KV_full[i].init(IS_PREFILL ? 1 : (128/32)*2+1); smem.bar_KV_empty[i].init(1); } CUTE_UNROLL for (int i = 0; i < NUM_INDEX_BUFS; ++i) { smem.bar_valid_coord_scales_full[i].init(IS_PREFILL ? B_TOPK/8 : 32); smem.bar_valid_coord_scales_empty[i].init(IS_PREFILL ? 
128 : (128 + (cta_idx==1) + 2 + 128)); } if constexpr (IS_DECODE) { CUTE_UNROLL for (int i = 0; i < NUM_RAW_K_BUFS; ++i) { smem.bar_raw_KV_full[i].init(1); smem.bar_raw_KV_empty[i].init(128); } } fence_barrier_init(); } ku::barrier_cluster_arrive_relaxed(); ku::barrier_cluster_wait_acquire(); struct OuterloopArgs { bool outer_loop_phase; int batch_idx, s_q_idx; int start_block_idx, end_block_idx; int topk_length; int extra_topk_length, num_orig_kv_blocks; // extra-KV related bool is_no_split; int n_split_idx; // splitkv related }; auto run_outer_loop = [&](auto loop_body) -> bool { int outer_loop_phase = false; if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) { int s_q_idx = blockIdx.x / 2; DecodingSchedMeta sched_meta; KU_LDG_256( params.tile_scheduler_metadata_ptr + blockIdx.y, &sched_meta, ".nc", "no_allocate", "evict_normal", "256B" ); if (sched_meta.begin_req_idx >= params.b) { return 0; } #pragma unroll 1 for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk; int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK); int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk; int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0 int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK; bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false); int n_split_idx = batch_idx == sched_meta.begin_req_idx ? 
(__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx); // start_block_idx = 0; // end_block_idx = total_topk_padded / B_TOPK; // is_split = false; // n_split_idx = 0; OuterloopArgs args = { (bool)outer_loop_phase, batch_idx, s_q_idx, start_block_idx, end_block_idx, topk_length, extra_topk_length, orig_topk_padded / B_TOPK, !is_split, n_split_idx }; loop_body(args); outer_loop_phase ^= 1; } } else { // Prefill mode. Use CLC to allocate different s_q (for decoding, different batches + s_q) to different workers ku::CLCResult next_job = {true, (int)blockIdx.x, IS_PREFILL ? 0 : (int)blockIdx.y, 0}; CUTE_NO_UNROLL while (next_job.is_valid) { int s_q_idx = next_job.x / 2; int batch_idx = IS_PREFILL ? 0 : next_job.y; int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + (IS_PREFILL?s_q_idx:batch_idx)) : params.topk; if constexpr (IS_PREFILL) { int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1 OuterloopArgs args = { (bool)outer_loop_phase, 0, s_q_idx, 0, num_k_blocks, topk_length }; loop_body(args); } else { int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK); int extra_topk_length = params.extra_topk_length ? 
__ldg(params.extra_topk_length + batch_idx) : params.extra_topk; int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0 OuterloopArgs args = { (bool)outer_loop_phase, batch_idx, s_q_idx, 0, total_topk_padded / B_TOPK, topk_length, extra_topk_length, orig_topk_padded / B_TOPK, false, 0 }; loop_body(args); } smem.bar_clc_full.wait(outer_loop_phase); next_job = ku::get_clc_query_response(smem.clc_response_obj); smem.bar_clc_empty.arrive(0u); outer_loop_phase ^= 1; } } return outer_loop_phase; }; if (warpgroup_idx == 0) { // Q fetching and O writing back warpgroup cutlass::arch::warpgroup_reg_alloc<176>(); bf16* sO_addrs[B_EPI/8]; CUTE_UNROLL for (int i = 0; i < B_EPI/8; ++i) { Tensor sO = make_tensor(make_smem_ptr(smem.Q.data()), ku::make_umma_canonical_k_major_layout()); sO_addrs[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*(D_V/2) + i*8); } float* sO_accum_addrs[B_EPI_SPLITKV/4]; if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) { // If split-KV is enabled, we need to store back O in float32 // We view Q buffer (with shape 64 x 512, bf16) as 4 buffers with shape (H_Q/2) x (B_EPI_SPLITKV*2), float32 Tensor sO_accum = make_tensor(make_smem_ptr((float*)smem.Q.data()), ku::make_umma_canonical_k_major_layout()); CUTE_UNROLL for (int i = 0; i < B_EPI_SPLITKV/4; ++i) { sO_accum_addrs[i] = &sO_accum(idx_in_warpgroup%64, i*4) + (idx_in_warpgroup >= 64 ? 
(H_Q/2)*B_EPI_SPLITKV : 0); } } auto perform_o_copy_out = [&](const OuterloopArgs &args, bool is_last_o) { // outer_loop_phase is the loop phase corresponding to s_q_idx // Get li (output_scale actually) smem.bar_li_full.wait(args.outer_loop_phase); float output_scale = smem.rowwise_li_buf[idx_in_warpgroup%64]; float2 output_scale_float2 = float2 {output_scale, output_scale}; smem.bar_li_empty.arrive(); // Retrieve and store O, and calculate delta := sum(O*dO, dim=-1) if FWD_MODE is Recompute smem.bar_tOut_full.wait(args.outer_loop_phase); if (is_last_o && elect_one_sync()) { cudaTriggerProgrammaticLaunchCompletion(); } if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) { CUTE_UNROLL for (int k = 0; k < (D_V/2)/B_EPI; ++k) { float2 o[B_EPI/2]; ku::tmem_ld_32dp32bNx(tmem_cols::O + k*B_EPI, o); cutlass::arch::fence_view_async_tmem_load(); if (k == (D_V/2)/B_EPI-1) { smem.bar_tOut_empty.arrive(0u); } CUTE_UNROLL for (int i = 0; i < B_EPI/8; ++i) { nv_bfloat162 o_bf16[4]; CUTE_UNROLL for (int j = 0; j < 4; ++j) { o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2); o_bf16[j] = __float22bfloat162_rn(o[i*4+j]); } bf16* o_do_addr = sO_addrs[i] + k*B_EPI*(H_Q/2); if (k == 0 && i == 0) { smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability } ku::st_shared(o_do_addr, *(__int128_t*)o_bf16); } } fence_view_async_shared(); NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); if (warp_idx == 0 && elect_one_sync()) { SM90_TMA_STORE_5D::copy( &tma_params.tensor_map_o, smem.Q.data(), 0, cta_idx*(H_Q/2), 0, args.s_q_idx, IS_DECODE ? 
args.batch_idx : 0 ); cute::tma_store_arrive(); } } else { CUTE_UNROLL for (int k = 0; k < (D_V/2)/B_EPI_SPLITKV; ++k) { int cur_buf_idx = k % NUM_EPI_SPLITKV_BUFS; if (k == 0) { cute::tma_store_wait<0>(); } else { cute::tma_store_wait(); } NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); float o[B_EPI_SPLITKV]; ku::tmem_ld_32dp32bNx(tmem_cols::O + k*B_EPI_SPLITKV, o); cutlass::arch::fence_view_async_tmem_load(); if (k == (D_V/2)/B_EPI_SPLITKV-1) { smem.bar_tOut_empty.arrive(0u); } CUTE_UNROLL for (int i = 0; i < B_EPI_SPLITKV/4; ++i) { CUTE_UNROLL for (int j = 0; j < 4; j += 2) { *(float2*)(o + i*4 + j) = ku::float2_mul(float2 {o[i*4+j], o[i*4+j+1]}, output_scale_float2); } if (k == 0 && i == 0) { smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability } ku::st_shared( sO_accum_addrs[i] + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2), *(__int128_t*)(o + i*4) ); } fence_view_async_shared(); NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); if constexpr (IS_DECODE) { // Otherwise nvcc complains about `tma_params` doesn't have `tensor_map_o_accum` float* cur_buf_base = (float*)smem.Q.data() + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2); if (warp_idx == 0 && elect_one_sync()) { SM90_TMA_STORE_5D::copy( &tma_params.tensor_map_o_accum, cur_buf_base, 0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32), args.s_q_idx, args.n_split_idx ); cute::tma_store_arrive(); } else if (warp_idx == 1 && elect_one_sync()) { SM90_TMA_STORE_5D::copy( &tma_params.tensor_map_o_accum, cur_buf_base + (H_Q/2)*B_EPI_SPLITKV, 0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32) + (D_V/2)/32, args.s_q_idx, args.n_split_idx ); cute::tma_store_arrive(); } } } } }; OuterloopArgs last_args; last_args.batch_idx = -1; bool final_outer_loop_phase = \ run_outer_loop([&](const OuterloopArgs &args) { // Copy Q for this round if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) { cute::tma_store_wait<0>(); NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); // Since we use two warps to 
issue TMA during FwdMode::DecodeWithSplitKV } if (warp_idx == 0 && elect_one_sync()) { // Wait for sQ to become empty, and issue G -> S copy for Q if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) { cute::tma_store_wait<0>(); // This thread must be the same one as o copy out thread (since `elect_one_sync()` always returns the same thread for the same `mask`, according to PTX document) } int stride_q_b_div_stride_q_s_q = 0; if constexpr (IS_DECODE) { stride_q_b_div_stride_q_s_q = params.stride_q_b / params.stride_q_s_q; } SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy( &tma_params.tensor_map_q, (uint64_t*)&smem.bar_sQ_full, (uint64_t)TMA::CacheHintSm90::EVICT_FIRST, smem.Q.data(), 0, cta_idx*(H_Q/2), 0, 0, (IS_DECODE ? args.batch_idx*stride_q_b_div_stride_q_s_q : 0) + args.s_q_idx ); // Wait for sQ to be ready, and issue S -> T copy for Q if (cta_idx == 0) { smem.bar_sQ_full.arrive_and_expect_tx(H_Q*D_Q*sizeof(bf16)); smem.bar_sQ_full.wait(args.outer_loop_phase); smem.bar_tQ_empty.wait(args.outer_loop_phase^1); ku::tcgen05_after_thread_sync(); UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc( make_tensor( make_smem_ptr(smem.Q.data()), ku::make_umma_canonical_k_major_layout<(H_Q/2)*2, 64, 128>() ) ); CUTE_UNROLL for (int tile_idx = 0; tile_idx < D_Q/64/2; ++tile_idx) { // A tile is 128 rows * 64 cols in UTCCP's view, or 64 rows * 128 cols in our view CUTE_UNROLL for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) { // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view) // NOTE Using `sQ_desc+((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4)` leads to IMA, doesn't know why UMMA::SmemDescriptor cur_sQ_desc = sQ_desc; cur_sQ_desc.lo += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4); // uint64_t cur_sQ_desc = sQ_desc; // cur_sQ_desc += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4); SM100_UTCCP_128dp256bit_2cta::copy( cur_sQ_desc, tmem_cols::Q + tile_idx*32 + subtile_idx*8 ); } } 
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tQ_full, 1|2); } } if (last_args.batch_idx != -1) { perform_o_copy_out(last_args, false); } else { smem.bar_tQ_full.wait(args.outer_loop_phase); // To prevent double arrive } last_args = args; }); if (last_args.batch_idx != -1) { cute::tma_store_wait<0>(); NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); perform_o_copy_out(last_args, true); } if (warp_idx == 0) { cute::TMEM::Allocator2Sm().free(0, 512); } } else if (warpgroup_idx == 1) { // KV fetching threads for prefill, dequant threads for decoding cutlass::arch::warpgroup_reg_dealloc<80>(); RingBufferState rs; if constexpr (!IS_DECODE) { const int warp_idx = cutlass::canonical_warp_idx(); // Using `warp_idx` without `__shfl_sync` is faster if (elect_one_sync()) { // KV fetching threads run_outer_loop([&](const OuterloopArgs &args) { int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q; int64_t cache_hint = ku::create_simple_cache_policy(); static constexpr int NUM_ROWS_PER_THREAD = B_TOPK / 4; CUTE_NO_UNROLL for (int k = args.start_block_idx; k < args.end_block_idx; ++k) { auto [k_buf_idx, k_bar_phase] = rs.get(); int cur_indices[NUM_ROWS_PER_THREAD]; CUTE_UNROLL for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/8; local_row += 1) { int row = local_row*(4*8) + (warp_idx-4)*8; KU_LDG_256( gIndices + k*B_TOPK + row, cur_indices + local_row*8, ".nc", "no_allocate", "evict_first", "256B" ); } smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1); CUTE_UNROLL for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/4; local_row += 1) { int row = (warp_idx-4)*8 + (local_row/2)*(4*8) + (local_row%2)*4; int4 indices = *(int4*)(cur_indices+local_row*4); static_assert(D_K == 512); CUTE_UNROLL for (int local_col = 0; local_col < (D_K/64)/2; ++local_col) { ku::tma_gather4_cta_group_2( &tma_params.tensor_map_kv, smem.bar_KV_full[k_buf_idx], smem.K[k_buf_idx].data() + row*64 + local_col*64*B_TOPK, local_col*64 + cta_idx*(D_K/2), indices, cache_hint 
); } } rs.update(); } }); } } else { // 8 threads per token struct IsCTA0 {}; struct IsCTA1 {}; auto launch_dequant_wg = [&](auto cta_id_t) { static constexpr bool IS_CTA1 = std::is_same::value; constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = (IS_CTA1 ? 256-64 : 256) / (GROUP_SIZE*8); int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE; Tensor nope0 = make_tensor(make_smem_ptr(smem.K[0].data()), ku::make_umma_canonical_k_major_layout()); bf16* nope0_base = &nope0(group_idx, idx_in_group*8); fp8_e4m3* raw_nope0_base = smem.K_raw[0].data() + group_idx*(D_K/2) + idx_in_group*8; run_outer_loop([&](const OuterloopArgs &args) { CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { auto [k_buf_idx, k_bar_phase] = rs.get(); auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get(); auto [index_buf_idx, index_bar_phase] = rs.get(); fp8_e4m3* raw_nope_base = raw_nope0_base + raw_k_buf_idx * (B_TOPK*(D_K/2)); auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t { return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*(D_K/2) + local_col_idx*(GROUP_SIZE*8)); }; bf16* nope_base = nope0_base + k_buf_idx * (B_TOPK*(D_K/2)); uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(nope_base); auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) { asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n" : : "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16) ); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS }; smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase); smem.bar_raw_KV_full[raw_k_buf_idx].wait(raw_k_bar_phase); CUTE_UNROLL for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) { int row_idx = local_row_idx*NUM_GROUPS + group_idx; bf16 
scales[4]; fp8_e8m0 scales_e8m0[4]; *(uint32_t*)scales_e8m0 = *(uint32_t*)(smem.scales[index_buf_idx][row_idx]); *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0)); *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2)); uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0); CUTE_UNROLL for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) { ku::nve4m3x2 data_fp8[4]; ku::nvbf16x2 data_bf16[4]; *(uint64_t*)data_fp8 = cur_data_fp8x8; if (local_col_idx+1 < COLS_PER_GROUP) cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1); bf16 scale = scales[local_col_idx]; CUTE_UNROLL for (int i = 0; i < 4; ++i) { data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale)); } if (local_row_idx == 0 && local_col_idx == 0) { smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1); } st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16); } } fence_view_async_shared(); // NOTE Should we use shared::cluster here? 
__syncwarp(); smem.bar_valid_coord_scales_empty[index_buf_idx].arrive(); smem.bar_raw_KV_empty[raw_k_buf_idx].arrive(); if (elect_one_sync()) { smem.bar_KV_full[k_buf_idx].arrive(0u); } rs.update(); } }); }; if (cta_idx == 0) { launch_dequant_wg(IsCTA0{}); } else { launch_dequant_wg(IsCTA1{}); } } } else if (warpgroup_idx == 2) { cutlass::arch::warpgroup_reg_dealloc<80>(); RingBufferState rs; if (warp_idx == 8 && cta_idx == 0 && elect_one_sync()) { // UMMA thread TiledMMA tiled_mma_P = TiledMMA_P{}; TiledMMA tiled_mma_O = TiledMMA_O{}; Tensor tP = partition_fragment_C(tiled_mma_P, Shape, Int>{}); Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A( partition_shape_A(tiled_mma_P, Shape, Int>{}) ); tP.data().get() = tmem_cols::P; tO.data().get() = tmem_cols::O; tQ.data().get() = tmem_cols::Q; run_outer_loop([&](const OuterloopArgs &args) { smem.bar_tQ_full.wait(args.outer_loop_phase); // Issue P = Q K^T auto issue_P = [&](int k, int rs_offset) { auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get(); auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>(); smem.bar_P_empty.wait(bar_phase^1); if constexpr (IS_PREFILL) { smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_K*sizeof(bf16)); } else { // RoPE only smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16)); } smem.bar_KV_full[k_buf_idx].wait(k_bar_phase); ku::tcgen05_after_thread_sync(); Tensor sK = make_tensor( make_smem_ptr(smem.K[k_buf_idx].data()), ku::make_umma_canonical_k_major_layout() ); ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true); ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_QK_done, 1|2); }; // Issue O += S V auto issue_O = [&](int k, int rs_offset) { auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get(); auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>(); smem.bar_S_O_full.wait(bar_phase); if (k == args.start_block_idx) { smem.bar_tOut_empty.wait(args.outer_loop_phase^1); } 
ku::tcgen05_after_thread_sync(); Tensor sS = make_tensor( make_smem_ptr(smem.S.data()), ku::make_umma_canonical_k_major_layout() ); Tensor sV = make_tensor( make_smem_ptr(smem.K[k_buf_idx].data()), ku::make_umma_canonical_mn_major_layout() ); ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == args.start_block_idx); ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_SV_done, 1|2); ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_KV_empty[k_buf_idx], 1|2); }; CUTE_NO_UNROLL for (int k = args.start_block_idx; k < args.end_block_idx+1; ++k) { if (k < args.end_block_idx) { issue_P(k, 0); } if (k == args.end_block_idx-1) { ku::umma_arrive_2x1SM_noelect(smem.bar_tQ_empty); } if (k > args.start_block_idx) { issue_O(k-1, -1); } if (k != args.end_block_idx) { rs.update(); } } ku::tcgen05_before_thread_sync(); ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tOut_full, 1|2); }); } else if (warp_idx == 8 && cta_idx == 1 && elect_one_sync()) { if constexpr (IS_DECODE) { // KV RoPE fetching warp run_outer_loop([&](const OuterloopArgs &args) { CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { auto [index_buf_idx, index_bar_phase] = rs.get(); auto [k_buf_idx, k_bar_phase] = rs.get(); smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase); smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1); CUTE_UNROLL for (int row = 0; row < B_TOPK; row += 4) { int4 cur_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row); ku::tma_gather4_cta_group_2( block_idx >= args.num_orig_kv_blocks ? 
&tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope, smem.bar_KV_full[k_buf_idx], smem.K[k_buf_idx].data() + (D_NOPE-D_K/2)*B_TOPK + row*D_ROPE, 0, cur_indices, (int64_t)TMA::CacheHintSm90::EVICT_LAST ); } smem.bar_valid_coord_scales_empty[index_buf_idx].arrive(); rs.update(); } }); } } else if (warp_idx == 9) { // KV validness loading warp (for prefill), Indices transformation warp (for decode, Responsible for generating: TMA coordinates, scale factors, and valid masks) if constexpr (IS_PREFILL) { if (lane_idx < B_TOPK/8) { run_outer_loop([&](const OuterloopArgs &args) { int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q; CUTE_NO_UNROLL for (int k = args.start_block_idx; k < args.end_block_idx; ++k) { char k_validness_mask = load_indices_and_generate_mask( lane_idx, gIndices + k*B_TOPK, params.s_kv, k*B_TOPK, args.topk_length ); auto [indices_buf_idx, indices_bar_phase] = rs.get(); smem.bar_valid_coord_scales_empty[indices_buf_idx].wait(indices_bar_phase^1); smem.is_k_valid[indices_buf_idx][lane_idx] = k_validness_mask; smem.bar_valid_coord_scales_full[indices_buf_idx].arrive(); rs.update(); } }); } } else { static_assert(B_TOPK == 64); // Each thread is responsible for 2 tokens static constexpr int tma_coords_step_per_token = 576/TMA_K_STRIDE_FOR_DECODING; int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE_FOR_DECODING; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE_FOR_DECODING > 512 int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE_FOR_DECODING; uint8_t* k_scales_ptr = (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE); uint8_t* extra_k_scales_ptr = (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE); run_outer_loop([&](const OuterloopArgs &args) { int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*args.s_q_idx; int* extra_indices = (int*)params.extra_indices + 
params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*args.s_q_idx; struct IsOrigBlock {}; struct IsExtraBlock {}; auto process_one_block = [&](int block_idx, auto is_extra_block_t) { auto [index_buf_idx, index_bar_phase] = rs.get(); static constexpr bool IS_EXTRA_BLOCK = std::is_same_v; int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size; int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block; [[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row; uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr; int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block; int abs_pos, my_indices[2]; if (!IS_EXTRA_BLOCK) { abs_pos = block_idx*B_TOPK + lane_idx*2; *(int2*)my_indices = __ldg((int2*)(indices + abs_pos)); } else { abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2; *(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos)); } smem.bar_valid_coord_scales_empty[index_buf_idx].wait(index_bar_phase^1); int tma_coords[2]; fp8_e8m0 scales[2*(NUM_SCALES_EACH_TOKEN/2)]; char valid_mask = 0; CUTE_UNROLL for (int i = 0; i < 2; ++i) { int block_idx, idx_in_block; block_idx = (unsigned int)my_indices[i] / cur_block_size; idx_in_block = (unsigned int)my_indices[i] % cur_block_size; bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length)); valid_mask |= is_token_valid << i; tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN. int64_t offset = block_idx*cur_k_block_stride + (idx_in_block*8 + (cta_idx == 1 ? 
4 : 0)); // Each token has 7 scale factors with an extra 1B padding uint32_t scalesx4 = is_token_valid ? __ldg((uint32_t*)(cur_k_scales_ptr + offset)) : 0; *(uint32_t*)(scales+i*(NUM_SCALES_EACH_TOKEN/2)) = scalesx4; } valid_mask <<= lane_idx%4*2; valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1); valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2); *(uint64_t*)(smem.scales[index_buf_idx] + lane_idx*2) = *(uint64_t*)scales; *(int2*)(smem.tma_coord[index_buf_idx] + lane_idx*2) = *(int2*)tma_coords; if (lane_idx%4 == 0) smem.is_k_valid[index_buf_idx][lane_idx/4] = valid_mask; smem.bar_valid_coord_scales_full[index_buf_idx].arrive(); rs.update(); }; CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) { process_one_block(block_idx, IsOrigBlock{}); } CUTE_NO_UNROLL for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) { process_one_block(block_idx, IsExtraBlock{}); } }); } } else if (warp_idx >= 10 && elect_one_sync()) { if constexpr (IS_PREFILL) { if (warp_idx == 10) { // CLC Producer thread run_outer_loop([&](const OuterloopArgs &args) { if (cta_idx == 0) { smem.bar_clc_empty.wait(args.outer_loop_phase^1); ku::issue_clc_query_multicast_cluster_all(smem.bar_clc_full, smem.clc_response_obj); } smem.bar_clc_full.arrive_and_expect_tx(sizeof(smem.clc_response_obj)); }); } } else { // Raw KV NoPE Producer thread run_outer_loop([&](const OuterloopArgs &args) { CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get(); auto [index_buf_idx, index_bar_phase] = rs.get(); smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase); smem.bar_raw_KV_empty[raw_k_buf_idx].wait(raw_k_bar_phase^1); int4 nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + (warp_idx == 10 ? 
0 : 4)); CUTE_UNROLL for (int row = (warp_idx == 10 ? 0 : 4); row < B_TOPK; row += 8) { int4 cur_indices = nxt_indices; if (row+8 < B_TOPK) nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row + 8); ku::tma_gather4( block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope, smem.bar_raw_KV_full[raw_k_buf_idx], smem.K_raw[raw_k_buf_idx].data() + row*(D_K/2), cta_idx*(D_K/2), cur_indices, (int64_t)TMA::CacheHintSm90::EVICT_LAST ); } if (warp_idx == 10) { smem.bar_raw_KV_full[raw_k_buf_idx].arrive_and_expect_tx(B_TOPK*(D_K/2)*sizeof(fp8_e4m3)); } smem.bar_valid_coord_scales_empty[index_buf_idx].arrive(); rs.update(); } }); } } } else { // Scale & Exp threads cutlass::arch::warpgroup_reg_alloc<176>(); int local_warp_idx = warp_idx - 12; bf16* sS_base = smem.S.data() + (local_warp_idx >= 2 ? (H_Q/2)*(B_TOPK/2) : 0) + (idx_in_warpgroup%64)*8; RingBufferState rs; run_outer_loop([&](const OuterloopArgs &args) { // For definition and consistency about `mi`, `li`, and `real_mi`, plz refer to head64 prefill float mi = MAX_INIT_VAL; float li = 0.0f; float real_mi = -CUDART_INF_F; static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2; CUTE_NO_UNROLL for (int k = args.start_block_idx; k < args.end_block_idx; ++k) { auto [k_buf_idx, k_bar_phase] = rs.get(); auto [indices_buf_idx, indices_bar_phase] = rs.get(); auto [_, bar_phase] = rs.get<1>(); // NOTE We don't need to sync for Prefill mode, since we have two synchronizations inside the loop body (one for p_exchange_buf sync, another one for rowwise_max_buf sync). 
The latter one guarantees the emptyness of p_exchange_buf and the former one guarantees the emptyness of rowwise_max_buf smem.bar_valid_coord_scales_full[indices_buf_idx].wait(indices_bar_phase); // Get P from TMEM float p[NUM_ELEMS_PER_THREAD]; smem.bar_QK_done.wait(bar_phase); ku::tcgen05_after_thread_sync(); retrieve_mask_and_reduce_p< NUM_ELEMS_PER_THREAD, tmem_cols::P, barrier_ids::WG2_WARP02_SYNC, barrier_ids::WG2_WARP13_SYNC, false >( smem.is_k_valid[indices_buf_idx], local_warp_idx, lane_idx, [&]() {smem.bar_P_empty.arrive(0u);}, smem.P_exchange, p ); // Get rowwise max of P float cur_pi_max = get_max(p); cur_pi_max *= params.sm_scale_div_log2; smem.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; NamedBarrier::arrive_and_wait(64, barrier_ids::WG2_WARP02_SYNC + (local_warp_idx&1)); cur_pi_max = max(cur_pi_max, smem.rowwise_max_buf[idx_in_warpgroup^64]); real_mi = max(real_mi, cur_pi_max); bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); // Calc scale factor, and scale li float new_max, scale_for_old; if (!should_scale_o) { // Don't scale O scale_for_old = 1.0f; new_max = mi; } else { new_max = max(cur_pi_max, mi); scale_for_old = exp2f(mi - new_max); } mi = new_max; // mi is still identical within each row // Calculate S nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2]; float cur_sum = get_s_from_p(s, p, params.sm_scale_div_log2, new_max); li = fmaf(li, scale_for_old, cur_sum); // Store S smem.bar_SV_done.wait(bar_phase^1); CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; ++i) { ku::st_shared(sS_base + i*8*(H_Q/2), *(__int128_t*)(s + i*4)); } // Rescale O if (k > 0 && should_scale_o) { ku::tcgen05_after_thread_sync(); rescale_O(scale_for_old); ku::tcgen05_before_thread_sync(); } fence_view_async_shared(); smem.bar_S_O_full.arrive(0u); smem.bar_valid_coord_scales_empty[indices_buf_idx].arrive(); rs.update(); } if (real_mi == -CUDART_INF_F) { // real_mi == -CUDART_INF_F <=> No valid TopK indices // We set li to 0 to fit the definition that li 
:= exp(x[i] - mi) li = 0.0f; mi = -CUDART_INF_F; } // Reduce li smem.bar_li_empty.wait(args.outer_loop_phase^1); smem.rowwise_li_buf[idx_in_warpgroup^64] = li; NamedBarrier::arrive_and_wait(128, barrier_ids::WG2_SYNC); li += smem.rowwise_li_buf[idx_in_warpgroup]; if (idx_in_warpgroup < H_Q/2) { // Calculate output_scale and save int head_idx = cta_idx*(H_Q/2) + idx_in_warpgroup; float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + head_idx); float output_scale; if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) { output_scale = __fdividef(1.0f, li + exp2f(fmaf(attn_sink, CUDART_L2E_F, -mi))); } else { output_scale = __fdividef(1.0f, li); } smem.rowwise_li_buf[idx_in_warpgroup] = li == 0.0f ? 0.0f : output_scale; smem.bar_li_full.arrive(); float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li)); cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse; if constexpr (IS_PREFILL) { int global_index = args.s_q_idx*params.h_q + head_idx; params.max_logits[global_index] = real_mi*CUDART_LN2_F; params.lse[global_index] = cur_lse; } else { if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) { params.lse[args.batch_idx*params.stride_lse_b + args.s_q_idx*params.stride_lse_s_q + head_idx] = cur_lse; } else { float cur_lse_2base = log2f(li) + mi; params.lse_accum[args.n_split_idx*params.stride_lse_accum_split + args.s_q_idx*params.stride_lse_accum_s_q + head_idx] = cur_lse_2base; } } } }); } ku::barrier_cluster_arrive_relaxed(); ku::barrier_cluster_wait_acquire(); #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100"); } #endif } // We have two launchers with different kernel names to distinguish prefill and decode template static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2) sparse_attn_fwd_for_small_topk_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) { 
Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params); } template static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2) flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) { Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params); } template void KernelTemplate::run(const ArgT& params) { static_assert(D_QK == 576 || D_QK == 512); KU_ASSERT(params.h_kv == 1); KU_ASSERT(params.topk % B_TOPK == 0); // To save some boundry checkings KU_ASSERT(params.h_q == H_Q); // To save some calculation KU_ASSERT(params.d_qk == D_QK); static_assert(D_Q == 512); CUtensorMap tensor_map_q; if constexpr (IS_DECODE) { KU_ASSERT(params.stride_q_b % params.stride_q_s_q == 0, "In decode mode for MODEL1 sparse fp8 decoding on sm100f, q.stride(0) (on the batch dimension) must be divisible by q.stride(1) (on the sequence dimension)."); tensor_map_q = ku::make_tensor_map( {64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.b * (params.stride_q_b / params.stride_q_s_q)}, ku::make_stride_helper({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)), {64, H_Q/2, 2, (D_Q/64)/2, 1}, params.q, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_L2_256B ); } else { tensor_map_q = ku::make_tensor_map( {64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.s_q}, ku::make_stride_helper({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)), {64, H_Q/2, 2, (D_Q/64)/2, 1}, params.q, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_L2_256B ); // We use this layout to group Q[0:64] and Q[256:256+64] together, for UTCCP for dual gemm } CUtensorMap tensor_map_kv; CUtensorMap tensor_map_kv_nope, tensor_map_kv_rope, tensor_map_extra_kv_nope = {}, tensor_map_extra_kv_rope = {}; if constexpr (IS_DECODE) { auto get_kv_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t 
stride_kv_block, int64_t stride_kv_row) -> std::pair { KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr); KU_ASSERT(stride_kv_block % TMA_K_STRIDE_FOR_DECODING == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary", is_extra?"extra_":"", stride_kv_block, TMA_K_STRIDE_FOR_DECODING); CUtensorMap tensor_map_kv_nope = ku::make_tensor_map( {D_NOPE + D_ROPE*2, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)}, {TMA_K_STRIDE_FOR_DECODING}, {D_K/2, 1}, k_ptr, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B ); // NOTE: Here we use `D_NOPE+D_ROPE*2` as the box shape instead of D_NOPE because it's actually faster. I think that's because, if we use `D_NOPE+D_ROPE*2`, we can prefetch part of the RoPE part of the selected tokens. CUtensorMap tensor_map_kv_rope = ku::make_tensor_map( {D_ROPE, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)}, {TMA_K_STRIDE_FOR_DECODING}, {64, 1}, (uint8_t*)k_ptr + D_NOPE, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B ); return {tensor_map_kv_nope, tensor_map_kv_rope}; }; std::tie(tensor_map_kv_nope, tensor_map_kv_rope) = get_kv_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block, params.stride_kv_row); if (params.extra_topk > 0) std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_kv_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block, params.stride_extra_kv_row); } else { tensor_map_kv = ku::make_tensor_map( {D_QK, (unsigned long)params.s_kv}, {(unsigned long)params.stride_kv_s_kv*sizeof(bf16)}, {64, 1}, params.kv, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CU_TENSOR_MAP_SWIZZLE_128B, 
CU_TENSOR_MAP_L2_PROMOTION_L2_256B ); } CUtensorMap tensor_map_o; if constexpr (IS_DECODE) { tensor_map_o = ku::make_tensor_map( {64, H_Q, D_V/64, (unsigned long)params.s_q, (unsigned long)params.b}, ku::make_stride_helper({params.stride_o_h_q, 64, params.stride_o_s_q, params.stride_o_b}, sizeof(bf16)), {64, H_Q/2, D_V/64, 1, 1}, params.out, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_L2_256B ); } else { tensor_map_o = ku::make_tensor_map( {64, H_Q, D_V/64, (unsigned long)params.s_q, 1ul}, ku::make_stride_helper({D_V, 64, H_Q*D_V, H_Q*D_V}, sizeof(bf16)), {64, H_Q/2, D_V/64, 1, 1}, params.out, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_L2_256B ); } CUtensorMap tensor_map_o_accum = {}; if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) { tensor_map_o_accum = ku::make_tensor_map( {32, H_Q, D_V/32, (unsigned long)params.s_q, (unsigned long)params.num_sm_parts + params.b}, ku::make_stride_helper({params.stride_o_accum_h_q, 32, params.stride_o_accum_s_q, params.stride_o_accum_split}, sizeof(float)), {32, H_Q/2, B_EPI_SPLITKV/32, 1, 1}, params.o_accum, CU_TENSOR_MAP_DATA_TYPE_FLOAT32, CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_L2_256B ); } TmaParams tma_params; if constexpr (IS_DECODE) { tma_params = { tensor_map_q, tensor_map_o, tensor_map_o_accum, tensor_map_kv_nope, tensor_map_kv_rope, tensor_map_extra_kv_nope, tensor_map_extra_kv_rope }; } else { tma_params = { tensor_map_q, tensor_map_kv, tensor_map_o }; } auto kernel = IS_PREFILL ? &sparse_attn_fwd_for_small_topk_kernel> : &flash_fwd_splitkv_mla_fp8_sparse_kernel>; constexpr size_t smem_size = sizeof(SharedMemoryPlan); KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); dim3 grid_shape; if constexpr (IS_DECODE) { grid_shape = dim3(2*params.s_q, FWD_MODE == FwdMode::DecodeWithSplitKV ? 
params.num_sm_parts : params.b, 1); } else { grid_shape = dim3(2*params.s_q, 1, 1); } cutlass::ClusterLaunchParams launch_params = { grid_shape, dim3(NUM_THREADS, 1, 1), dim3(2, 1, 1), smem_size, params.stream }; KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster( launch_params, (void*)kernel, params, tma_params )); } template void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT& params) { using Kernel = KernelTemplate; Kernel::run(params); } } ================================================ FILE: csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h ================================================ #pragma once #include "params.h" namespace sm100::fwd_for_small_topk::head128 { template void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT& params); } ================================================ FILE: csrc/sm90/decode/dense/config.h ================================================ #pragma once namespace Config { static constexpr int BLOCK_SIZE_M = 64; static constexpr int PAGE_BLOCK_SIZE = 64; static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_V = 512; } ================================================ FILE: csrc/sm90/decode/dense/instantiations/bf16.cu ================================================ #include "../splitkv_mla.cuh" #include "../splitkv_mla.h" namespace sm90 { template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); } ================================================ FILE: csrc/sm90/decode/dense/instantiations/fp16.cu ================================================ #include "../splitkv_mla.cuh" #include "../splitkv_mla.h" namespace sm90 { #ifndef FLASH_MLA_DISABLE_FP16 template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); #endif } ================================================ FILE: csrc/sm90/decode/dense/splitkv_mla.cuh ================================================ #include #include "utils.h" #include "params.h" #include "config.h" #include "traits.h" using namespace 
cute;
using cutlass::arch::NamedBarrier;

namespace sm90 {

// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking
// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2)
// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM
static constexpr float MAX_INIT_VAL_SM = -1e30f;
static constexpr float MAX_INIT_VAL = -1e33f;

// Return the row of the warpgroup's 64-row WGMMA A/C fragment that holds this thread's
// `local_row_idx`-th row of data.
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
    // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in
    // two particular rows. This function converts the local_row_idx (0 or 1) to the actual row_idx.
    // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
    // How the row is formed:
    //   - each warp (idx_in_warpgroup/32) owns a 16-row slab;
    //   - within the slab, every 4 consecutive lanes share a row (idx_in_warpgroup%32/4);
    //   - local_row_idx selects the first or second 8-row half.
    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
    return row_idx;
}

// Launch TMA copy for a range of KV tile
// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64
// Only thread 0 of the warpgroup issues copies. Each call:
//   - issues one TMA copy from gKV to sKV with an EVICT_FIRST cache hint;
//   - has that copy complete on barriers_K[START_HEAD_DIM_TILE_IDX].
// The `if constexpr` guard recurses while START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX.
// Assuming each step advances START by one, tiles [START, END) are issued.
template<
    int START_HEAD_DIM_TILE_IDX,
    int END_HEAD_DIM_TILE_IDX,
    typename TMA_K_OneTile,
    typename Engine0, typename Layout0,
    typename Engine1, typename Layout1
>
__forceinline__ __device__ void launch_kv_tiles_copy_tma(
    Tensor const &gKV,  // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
    Tensor &sKV,        // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled
    TMA_K_OneTile &tma_K,
    TMABarrier* barriers_K,
    int idx_in_warpgroup
) {
    if (idx_in_warpgroup == 0) {
        auto thr_tma = tma_K.get_slice(_0{});
        Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{});
        Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int{});
        cute::copy(tma_K.with(reinterpret_cast(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV);
        if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
            launch_kv_tiles_copy_tma(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup);
        }
    }
}

// Prefetch some KV tiles
// Currently this is not used because it leads to performance degradation
// Only thread 0 of the warpgroup issues `cute::prefetch` on the TMA source tiles.
// No shared-memory destination or barrier is involved. The recursion mirrors launch_kv_tiles_copy_tma.
template<
    int START_HEAD_DIM_TILE_IDX,
    int END_HEAD_DIM_TILE_IDX,
    typename TMA_K_OneTile,
    typename Engine0, typename Layout0
>
__forceinline__ __device__ void prefetch_kv_tiles(
    Tensor const &gKV,  // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
    TMA_K_OneTile &tma_K,
    int idx_in_warpgroup
) {
    if (idx_in_warpgroup == 0) {
        auto thr_tma = tma_K.get_slice(_0{});
        Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{});
        cute::prefetch(tma_K, cur_gKV);
        if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
            prefetch_kv_tiles(gKV, tma_K, idx_in_warpgroup);
        }
    }
}

// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao.
// Warpgroup-wide GEMM tCrC (+)= tCrA * tCrB, manually unrolled over the K mode (mode 2).
// Compile-time flags referenced in the body:
//   - `arrive`: issue warpgroup_arrive() first;
//   - `zero_init`: overwrite tCrC on the first k_block (ScaleOut::Zero);
//   - `commit`: close the WGMMA batch;
//   - `wg_wait >= 0`: wait on outstanding batches before returning.
// Without zero_init, the first k_block uses whatever tiled_mma.accumulate_ the caller set.
// NOTE(review): Is_RS presumably means "A is read from registers" (the RS WGMMA variant); confirm.
template __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); }
}

// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64)
// and one KV-tile (PAGE_BLOCK_SIZEx64)
// The Q-tile should be in shared memory
// Synchronization: thread 0 arms `barrier` with the expected transaction byte count, 64*64*2
// (a 64x64 tile, presumably of a 2-byte bf16/fp16 element). Then every thread waits on parity
// `cur_phase`. This function does not flip `cur_phase`; the caller does.
// Accumulation: the first of the four 16-wide k-steps uses the caller's tiled_mma.accumulate_, and
// the remaining ones always accumulate.
// The WGMMA batch is committed but not waited on here.
template<
    typename TiledMMA,
    typename Engine0, typename Layout0,
    typename Engine1, typename Layout1,
    typename Engine2, typename Layout2
>
__forceinline__ __device__ void qkt_gemm_one_tile_sQ(
    TiledMMA &tiled_mma,
    Tensor const &thr_mma_sQ_tile,   // (MMA, 1, 4)
    Tensor const &thr_mma_sKV_tile,  // (MMA, 1, 4)
    Tensor &rP,  // ((2, 2, 8), 1, 1)
    TMABarrier* barrier,
    bool &cur_phase,
    int idx_in_warpgroup
) {
    if (idx_in_warpgroup == 0) {
        barrier->arrive_and_expect_tx(64*64*2);
    }
    barrier->wait(cur_phase ? 1 : 0);

    warpgroup_fence_operand(rP);
    warpgroup_arrive();
    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
    tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
    warpgroup_commit_batch();
    warpgroup_fence_operand(rP);
}

// Same as qkt_gemm_one_tile_sQ, but the Q-tile is the register fragment `thr_mma_rQ_tile`.
// That fragment is fenced before and after the GEMM as well as rP.
template<
    typename TiledMMA,
    typename Engine0, typename Layout0,
    typename Engine1, typename Layout1,
    typename Engine2, typename Layout2
>
__forceinline__ __device__ void qkt_gemm_one_tile_rQ(
    TiledMMA &tiled_mma,
    Tensor const &thr_mma_rQ_tile,   // (MMA, 1, 4)
    Tensor const &thr_mma_sKV_tile,  // (MMA, 1, 4)
    Tensor &rP,  // ((2, 2, 8), 1, 1)
    TMABarrier* barrier,
    bool &cur_phase,
    int idx_in_warpgroup
) {
    if (idx_in_warpgroup == 0) {
        barrier->arrive_and_expect_tx(64*64*2);
    }
    barrier->wait(cur_phase ? 1 : 0);

    warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile));
    warpgroup_fence_operand(rP);
    warpgroup_arrive();
    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
    tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
    warpgroup_commit_batch();
    warpgroup_fence_operand(rP);
    warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile));
}

// Pipelined TMA wait and Q K^T gemm
// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into
// tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the
// computation as follows:
// - Wait for the 0-th tile to be ready using `barrier.wait()`
// - Compute Q K^T for the 0-th tile
// - Wait for the 1-st tile to be ready
// - Compute Q K^T for the 1-st tile
// ...
// This gives latter tiles more time to be ready, and thus can overlap the memory copy and computation
//
// Tile sources: tiles 0..7 of Q come from sQ. Tile 8 (columns 512..575 of HEAD_DIM_K) comes from
// the register fragment rQ8.
// Accumulation: phases 0 and 1 start a fresh P (sQ ScaleOut::Zero), while phase 2 keeps
// accumulating onto the P produced by phase 0.
// Parity: `cur_phase` is flipped only once all 9 tile barriers of the current KV block have been
// waited on, i.e. at the end of phase 1 and of phase 2, not after phase 0.
// NOTE(review): QKT_GEMM_ONE_TILE is not #undef'd at the end of this function.
template<
    typename T,  // Traits
    int PHASE_IDX,  // See comments in the code
    typename Engine0, typename Layout0,
    typename Engine1, typename Layout1,
    typename Engine2, typename Layout2,
    typename Engine3, typename Layout3
>
__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm(
    Tensor &sQ,   // (BLOCK_SIZE_M, HEAD_DIM_K)
    Tensor &sKV,  // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
    Tensor &rP,   // ((2, 2, 8), 1, 1)
    Tensor &rQ8,  // The 8-th tile of Q. We store it separately to leave some room for storing sP1
    TMABarrier* barriers,
    bool &cur_phase,
    int idx_in_warpgroup
) {
    Tensor sQ_tiled = flat_divide(sQ, Shape, _64>{})(_, _, _0{}, _);    // (BLOCK_SIZE_M, 64, 9)
    Tensor sKV_tiled = flat_divide(sKV, Shape, _64>{})(_, _, _0{}, _);  // (PAGE_BLOCK_SIZE, 64, 9)
    TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){};
    ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup);
    Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled);     // (MMA, 1, 4, 9)
    Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled);  // (MMA, 1, 4, 9)
    TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){};

#define QKT_GEMM_ONE_TILE(TILE_IDX) \
    if constexpr(TILE_IDX != 8) { \
        qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int{}), thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
    } else { \
        qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
    }

    if constexpr (PHASE_IDX == 0) {
        // In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles
        tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
        tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
        QKT_GEMM_ONE_TILE(0);
        QKT_GEMM_ONE_TILE(1);
        QKT_GEMM_ONE_TILE(2);
        QKT_GEMM_ONE_TILE(3);
    } else if constexpr (PHASE_IDX == 1) {
        // In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles
        tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
        tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
        QKT_GEMM_ONE_TILE(4);
        QKT_GEMM_ONE_TILE(5);
        QKT_GEMM_ONE_TILE(6);
        QKT_GEMM_ONE_TILE(7);
        QKT_GEMM_ONE_TILE(8);
        QKT_GEMM_ONE_TILE(0);
        QKT_GEMM_ONE_TILE(1);
        QKT_GEMM_ONE_TILE(2);
        QKT_GEMM_ONE_TILE(3);
        cur_phase ^= 1;
    } else {
        // In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles
        static_assert(PHASE_IDX == 2);
        tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One;
        tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
        QKT_GEMM_ONE_TILE(4);
        QKT_GEMM_ONE_TILE(5);
        QKT_GEMM_ONE_TILE(6);
        QKT_GEMM_ONE_TILE(7);
        QKT_GEMM_ONE_TILE(8);
        cur_phase ^= 1;
    }
}

// Non-pipelined variant: computes P over the full HEAD_DIM_K with a single `gemm` call.
// It does no per-tile barrier waits, so the caller must ensure sKV is already resident.
template<
    typename T,
    typename Engine0, typename Layout0,
    typename Engine1, typename Layout1,
    typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline(
    Tensor &sQ,   // (BLOCK_SIZE_M, HEAD_DIM_K)
    Tensor &sKV,  // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
    Tensor &rP,   // ((2, 2, 8), 1, 1)
    int idx_in_warpgroup
) {
    TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){};
    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
    Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ);    // (MMA, 1, 576/16=36)
    Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV);  // (MMA, 1, 576/16=36)
    gemm(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP);
}

// Compute O += PV, where P resides in register
// rP (the Q K^T accumulator fragment) is re-viewed through a hand-written Layout as an A-operand
// fragment with 4 k-blocks spaced 8 elements apart. This matches the 64/16=4 k-blocks of sKV_half.
template<
    typename T,
    typename Engine0, typename Layout0,
    typename Engine1, typename Layout1,
    typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP(
    Tensor &rP,        // ((2, 2, 8), 1, 1), fragment A layout
    Tensor &sKV_half,  // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
    Tensor &rO,        // ((2, 2, 32), 1, 1)
    int idx_in_warpgroup
) {
    TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){};
    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
    Tensor rP_retiled = make_tensor(rP.data(), Layout<
        Shape, _1, _4>,
        Stride, _0, _8>
    >{});
    Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half);  // (MMA, 1, 64/16=4)
    gemm(tiled_mma, rP_retiled, thr_mma_sKV_half, rO);
}

// Compute O += PV, where P resides in shared memory
template<
    typename T,
    typename Engine0, typename Layout0,
    typename Engine1, typename Layout1,
    typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP(
    Tensor &sP,
    Tensor &sKV_half,  // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
    Tensor &rO,        // ((2, 2, 32), 1, 1)
    int idx_in_warpgroup
) {
    TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){};
    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
    Tensor thr_mma_sP =
thr_mma.partition_fragment_A(sP); Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4) gemm(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO); } template< typename T, bool DO_OOB_FILLING, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3, typename Engine4, typename Layout4 > __forceinline__ __device__ void wg0_bunch_0( Tensor &rPb, // ((2, 2, 8), 1, 1) Tensor &rP0, // ((2, 2, 8), 1, 1) Tensor &rO0, // ((2, 2, 32), 1, 1) Tensor &sScale0, // (BLOCK_SIZE_M) Tensor &sM, // (BLOCK_SIZE_M) float rL[2], int rRightBorderForQSeq[2], float scale_softmax_log2, int start_token_idx, int idx_in_warpgroup ) { // This piece of code is tightly coupled [Accumulate's layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png) CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); // Mask, and get row-wise max float cur_max = MAX_INIT_VAL; CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { if constexpr (DO_OOB_FILLING) { int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL; rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL; } cur_max = max(cur_max, max(rP0(i), rP0(i+1))); } cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); // Update sM and sL cur_max *= scale_softmax_log2; float new_max = max(sM(row_idx), cur_max); float scale_for_old = exp2f(sM(row_idx) - new_max); __syncwarp(); // Make sure all reads have finished before updating sM if (idx_in_warpgroup%4 == 0) { sScale0(row_idx) = scale_for_old; sM(row_idx) = new_max; } // Scale-O CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 
2 : 0; i < size(rO0); i += 4) { rO0(i) *= scale_for_old; rO0(i+1) *= scale_for_old; } // Scale, exp, and get row-wise expsum float cur_sum = 0; CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max); rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max); rPb(i) = (typename T::InputT)rP0(i); rPb(i+1) = (typename T::InputT)rP0(i+1); cur_sum += rP0(i) + rP0(i+1); } rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; } } template< typename T, bool IS_BLK0_LAST, bool IS_BLK1_LAST, bool IS_BLK2_LAST, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3, typename Engine4, typename Layout4, typename Engine5, typename Layout5 > __forceinline__ __device__ void wg1_bunch_0( Tensor &rP1b, // ((2, 2, 8), 1, 1) Tensor &sScale1, // (BLOCK_SIZE_M) Tensor &rO1, // ((2, 2, 32), 1, 1) Tensor &sM, // (BLOCK_SIZE_M) float rL[2], int rRightBorderForQSeq[2], Tensor const &sScale0, // (BLOCK_SIZE_M) Tensor &rP1, // ((2, 2, 8), 1, 1) float scale_softmax_log2, int start_token_idx, int idx_in_warpgroup ) { CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); // Mask, and get row-wise max float cur_max = MAX_INIT_VAL; CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) { // Need to apply the mask when either this block is the last one, or // the next block is the last one (because of the causal mask) int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL; rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? 
rP1(i+1) : MAX_INIT_VAL; } else if constexpr (IS_BLK0_LAST) { rP1(i) = rP1(i+1) = MAX_INIT_VAL; } cur_max = max(cur_max, max(rP1(i), rP1(i+1))); } cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); cur_max *= scale_softmax_log2; float old_max = sM(row_idx); float new_max = max(old_max, cur_max); float scale_for_old = exp2f(old_max - new_max); __syncwarp(); if (idx_in_warpgroup%4 == 0) { sM(row_idx) = new_max; sScale1(row_idx) = scale_for_old; } // Scale, exp, and get row-wise expsum float cur_sum = 0; if constexpr (!IS_BLK0_LAST) { CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max); rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max); rP1b(i) = (typename T::InputT)rP1(i); rP1b(i+1) = (typename T::InputT)rP1(i+1); cur_sum += rP1(i) + rP1(i+1); } } // Scale O float cur_scale_for_o1 = scale_for_old * sScale0(row_idx); CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 
2 : 0; i < size(rO1); i += 4) { rO1(i) *= cur_scale_for_o1; rO1(i+1) *= cur_scale_for_o1; } // Update rL rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum; } } // Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction template< typename T, typename Engine0, typename Layout0, typename Engine1, typename Layout1 > __forceinline__ __device__ void save_rPb_to_sP( Tensor &rPb, Tensor &sP, int idx_in_warpgroup ) { auto r2s_copy = make_tiled_copy_C( Copy_Atom{}, (typename T::TiledMMA_QK_sQ){} ); ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); Tensor thr_copy_rPb = thr_copy.retile_S(rPb); Tensor thr_copy_sP = thr_copy.partition_D(sP); cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); } // Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction template< typename T, typename Engine0, typename Layout0, typename Engine1, typename Layout1 > __forceinline__ __device__ void retrieve_rP_from_sP( Tensor &rPb, Tensor const &sP, int idx_in_warpgroup ) { TiledCopy s2r_copy = make_tiled_copy_A( Copy_Atom{}, (typename T::TiledMMA_PV_LocalP){} ); ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup); Tensor thr_copy_sP = thr_copy.partition_S(sP); Tensor thr_copy_rPb = thr_copy.retile_D(rPb); cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb); } // Rescale rP0 and save the result to rPb template< typename T, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2 > __forceinline__ __device__ void wg0_scale_rP0( Tensor const &sScale1, // (BLOCK_M) Tensor const &rP0, // ((2, 2, 8), 1, 1) Tensor &rPb, // ((2, 2, 8), 1, 1) int idx_in_warpgroup ) { CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); float scale_factor = sScale1(row_idx); CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 
2 : 0; i < size(rP0); i += 4) { rPb(i) = (typename T::InputT)(rP0(i)*scale_factor); rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor); } } } // Rescale rO0 according to sScale1 template< typename Engine0, typename Layout0, typename Engine1, typename Layout1 > __forceinline__ __device__ void wg0_rescale_rO0( Tensor &rO0, Tensor &sScale1, float rL[2], int idx_in_warpgroup ) { CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); float scale_factor = sScale1(row_idx); CUTLASS_PRAGMA_UNROLL for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) { rO0(i) *= scale_factor; rO0(i+1) *= scale_factor; } rL[local_row_idx] *= scale_factor; } } // Fill out-of-bound V with 0.0 // We must fill it since it may contain NaN, which may propagate to the final result template< typename T, typename Engine0, typename Layout0 > __forceinline__ __device__ void fill_oob_V( Tensor &sV, // tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape, Int>{}, LayoutRight{} ); int valid_window_size, int idx_in_warpgroup ) { Tensor sV_int64 = make_tensor( make_smem_ptr((int64_t*)(sV.data().get().get())), tile_to_shape( GMMA::Layout_MN_SW128_Atom{}, Shape, Int>{}, LayoutRight{} ) ); valid_window_size = max(valid_window_size, 0); int head_dim_size = size<0>(sV_int64); // 128%head_dim_size == 0 should holds for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) { sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0; } } // Store O / OAccum template< typename T, bool IS_NO_SPLIT, typename TMAParams, typename Engine0, typename Layout0, typename Engine1, typename Layout1 > __forceinline__ __device__ void store_o( Tensor &rO, // ((2, 2, 32), 1, 1) Tensor &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) float rL[2], char* sO_addr, TMAParams &tma_params, int batch_idx, int k_head_idx, int m_block_idx, int num_valid_seq_q, int 
warpgroup_idx, int idx_in_warpgroup ) { using InputT = typename T::InputT; if constexpr (IS_NO_SPLIT) { // Should convert the output to bfloat16 / float16, and save it to O Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); Tensor rOb = make_tensor_like(rO); CUTLASS_PRAGMA_UNROLL for (int idx = 0; idx < size(rO); ++idx) { rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]); } Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); TiledCopy r2s_tiled_copy = make_tiled_copy_C( Copy_Atom{}, (typename T::TiledMMA_PV_LocalP){} ); ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); cutlass::arch::fence_view_async_shared(); __syncthreads(); if (threadIdx.x == 0) { Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) auto thr_tma = tma_params.tma_O.get_slice(_0{}); Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, m_block_idx, _0{}); cute::copy( tma_params.tma_O, thr_tma.partition_S(sOutputBuf), thr_tma.partition_D(my_tma_gO) ); cute::tma_store_arrive(); } } else { // Should save the result to OAccum Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout< Shape<_64, _512>, Stride, _1> // We use stride = 520 here to avoid bank conflict >{}); CUTLASS_PRAGMA_UNROLL for (int idx = 0; idx < size(rO); idx += 2) { int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 
8 : 0); int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; *(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 { rO(idx) / rL[idx%4 >= 2], rO(idx+1) / rL[idx%4 >= 2], }; } cutlass::arch::fence_view_async_shared(); __syncthreads(); int row = threadIdx.x; if (row < num_valid_seq_q) { SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float)); cute::tma_store_arrive(); } } } template< typename T, typename TmaParams, typename Tensor0 > __forceinline__ __device__ void launch_q_copy( TmaParams const &tma_params, int batch_idx, int m_block_idx, int k_head_idx, Tensor0 &sQ, TMABarrier* barrier_Q ) { if (threadIdx.x == 0) { Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) auto thr_tma = tma_params.tma_Q.get_slice(_0{}); Tensor my_tma_gQ = flat_divide(tma_gQ, Shape, Int>{})(_, _, m_block_idx, _0{}); cute::copy( tma_params.tma_Q.with(reinterpret_cast(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), thr_tma.partition_S(my_tma_gQ), thr_tma.partition_D(sQ) ); barrier_Q->arrive_and_expect_tx(64*576*2); } } template< typename T, bool IS_R, typename Engine0, typename Layout0 > __forceinline__ __device__ auto get_half_V( Tensor &sK ) { Tensor sV = make_tensor(sK.data(), (typename T::SmemLayoutV){}); return flat_divide(sV, Shape, Int>{})(_, _, Int<(int)IS_R>{}, _0{}); } template< typename T, bool IS_BLK0_LAST, // "BLK0" means block_idx+0, "BLK1" means block_idx+1, ... 
bool IS_BLK1_LAST, typename TMAParams, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3, typename Engine4, typename Layout4, typename Engine5, typename Layout5, typename Engine6, typename Layout6, typename Engine7, typename Layout7, typename Engine8, typename Layout8, typename Engine9, typename Layout9, typename Engine10, typename Layout10, typename Engine11, typename Layout11 > __forceinline__ __device__ void wg0_subroutine( Tensor &tma_gK, Tensor &sQ, Tensor &sK0, Tensor &sK1, Tensor &sP0, Tensor &sP1, Tensor &sM, Tensor &sScale0, Tensor &sScale1, Tensor &rQ8, Tensor &rP0, Tensor &rO0, float rL[2], int rRightBorderForQSeq[2], TMABarrier barriers_K0[9], TMABarrier barriers_K1[9], bool &cur_phase_K0, const TMAParams &tma_params, const DenseAttnDecodeParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, int end_block_idx, int idx_in_warpgroup ) { int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; #define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 
0 : __ldg(block_table_ptr + (block_idx))) int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); Tensor sV0L = get_half_V(sK0); Tensor sV1L = get_half_V(sK1); Tensor rPb = make_tensor(Shape, _1, _4>{}); // Calc P0 = softmax(P0) wg0_bunch_0(rPb, rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup); NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready); // Issue rO0 += rPb @ sV0L if constexpr (IS_BLK0_LAST) { fill_oob_V(sV0L, seqlen_k-start_token_idx, idx_in_warpgroup); cutlass::arch::fence_view_async_shared(); } warpgroup_cooperative_pv_gemm_localP(rPb, sV0L, rO0, idx_in_warpgroup); // Wait for rO0, launch TMA for the next V0L cute::warpgroup_wait<0>(); // Wait for warpgroup 1, rescale P0, notify warpgroup 1 NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready); if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { // Put it here seems to be faster, don't know why launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); } wg0_scale_rP0(sScale1, rP0, rPb, idx_in_warpgroup); save_rPb_to_sP(rPb, sP0, idx_in_warpgroup); cutlass::arch::fence_view_async_shared(); NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready); // Wait for warpgroup 1, rescale O0, issue rO0 += rPb @ sV1L if constexpr (!IS_BLK0_LAST) { if constexpr (IS_BLK1_LAST) { fill_oob_V(sV1L, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); cutlass::arch::fence_view_async_shared(); } NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup); warpgroup_cooperative_pv_gemm_remoteP(sP1, sV1L, rO0, idx_in_warpgroup); } // Issue P0 = Q @ K0^T // Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline. 
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); } // Wait for rO0 += rPb @ sV1L, launch TMA if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) { cute::warpgroup_wait<4>(); launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); } // Issue P0 = Q @ K0^T if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); } // Wait for P0 = Q @ K0^T cute::warpgroup_wait<0>(); } template< typename T, bool IS_BLK0_LAST, bool IS_BLK1_LAST, bool IS_BLK2_LAST, typename TMAParams, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3, typename Engine4, typename Layout4, typename Engine5, typename Layout5, typename Engine6, typename Layout6, typename Engine7, typename Layout7, typename Engine8, typename Layout8, typename Engine9, typename Layout9, typename Engine10, typename Layout10, typename Engine11, typename Layout11 > __forceinline__ __device__ void wg1_subroutine( Tensor &tma_gK, Tensor &sQ, Tensor &sK0, Tensor &sK1, Tensor &sP0, Tensor &sP1, Tensor &sM, Tensor &sScale0, Tensor &sScale1, Tensor &rQ8, Tensor &rP1, Tensor &rO1, float rL[2], int rRightBorderForQSeq[2], TMABarrier barriers_K0[9], TMABarrier barriers_K1[9], bool &cur_phase_K1, const TMAParams &tma_params, const DenseAttnDecodeParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, int end_block_idx, int idx_in_warpgroup ) { int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); Tensor rP1b = make_tensor(Shape, _1, _4>{}); Tensor sV0R = get_half_V(sK0); Tensor sV1R = get_half_V(sK1); // Wait for rP1 and warpgroup 0, run bunch 1, notify 
warpgroup 0 NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready); wg1_bunch_0(rP1b, sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup); NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready); // Save rPb to sP, and issue rO1 += rP1b @ sV1R // We do this after notifying warpgroup 1, since both "saving rPb to sP" and "issuing" WGMMA are high-latency operations if constexpr (!IS_BLK0_LAST) { save_rPb_to_sP(rP1b, sP1, idx_in_warpgroup); } if constexpr (!IS_BLK0_LAST) { if constexpr (IS_BLK1_LAST) { fill_oob_V(sV1R, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); cutlass::arch::fence_view_async_shared(); } warpgroup_cooperative_pv_gemm_localP(rP1b, sV1R, rO1, idx_in_warpgroup); if constexpr (!IS_BLK1_LAST) { // We use this proxy for making sP1 visible to the async proxy // We skip it if IS_BLK1_LAST, since in that case we have already put a fence cutlass::arch::fence_view_async_shared(); } } // Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0 NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready); if constexpr (IS_BLK0_LAST) { fill_oob_V(sV0R, seqlen_k-start_token_idx, idx_in_warpgroup); cutlass::arch::fence_view_async_shared(); } warpgroup_cooperative_pv_gemm_remoteP(sP0, sV0R, rO1, idx_in_warpgroup); if constexpr (!IS_BLK0_LAST) { NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); } // Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { cute::warpgroup_wait<1>(); launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); } // Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { cute::warpgroup_wait<0>(); launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, 
idx_in_warpgroup); } if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { // Issue rP1 = sQ @ sK1, wait warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); } // We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise // nvcc cannot correctly analyse the loop, and will think that we are using accumulator // registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized. // This is also the reason why we put QK^T here, instead of the first operation in the loop cute::warpgroup_wait<0>(); } // A helper function for determining the length of the causal mask for one q token __forceinline__ __device__ int get_mask_len(const DenseAttnDecodeParams ¶ms, int m_block_idx, int local_seq_q_idx) { int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; if (global_seq_q_idx < params.q_seq_per_hk) { int s_q_idx = global_seq_q_idx / params.q_head_per_hk; return params.s_q - s_q_idx - 1; } else { // Out-of-bound request, regard as no masks return 0; } } template __global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) flash_fwd_splitkv_mla_kernel(__grid_constant__ const DenseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) { // grid shape: [ // num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))), // num_kv_heads, // num_sm_parts // ] // An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx). // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. 
Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]) // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file). #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) const int m_block_idx = blockIdx.x; const int k_head_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int warpgroup_idx = threadIdx.x / 128; const int idx_in_warpgroup = threadIdx.x % 128; // Define shared tensors extern __shared__ char wksp_buf[]; using SharedMemoryPlan = typename T::SharedMemoryPlan; SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ.data()), (typename T::SmemLayoutQ){}); Tensor sK0 = make_tensor(make_smem_ptr(plan.smem_sK0.data()), (typename T::SmemLayoutK){}); Tensor sK1 = make_tensor(make_smem_ptr(plan.smem_sK1.data()), (typename T::SmemLayoutK){}); Tensor sP0 = make_tensor(make_smem_ptr(plan.smem_sP0.data()), (typename T::SmemLayoutP0){}); Tensor sP1 = flat_divide(sQ, Shape, Int>{})(_, _, _0{}, _8{}); // Overlap with sQ's 8-th tile Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int{})); Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{})); Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int{})); Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int{})); char* sO_addr = (char*)plan.smem_sK0.data(); // Overlap with sK0 and sK1 // Prefetch TMA descriptors if (threadIdx.x == 0) { cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); } // Define TMA stuffs Tensor tma_gK = 
tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _); TMABarrier* barriers_K0 = plan.barriers_K0; TMABarrier* barriers_K1 = plan.barriers_K1; TMABarrier* barrier_Q = &(plan.barrier_Q); // Initialize TMA barriers if (threadIdx.x == 0) { barrier_Q->init(1); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 9; ++i) { barriers_K0[i].init(1); barriers_K1[i].init(1); } cutlass::arch::fence_view_async_shared(); } __syncthreads(); bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0; DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx]; if (sched_meta.begin_req_idx >= params.b) return; // Copy the first Q launch_q_copy(tma_params, sched_meta.begin_req_idx, m_block_idx, k_head_idx, sQ, barrier_Q); #pragma unroll 1 for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { constexpr int kBlockN = T::PAGE_BLOCK_SIZE; const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0; int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN); const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true); int rRightBorderForQSeq[2]; if (params.is_causal) { // The causal mask looks like: // XXXX // XXXX // ... // XXXX // XXX // XXX // ... // XXX // XX // XX // ... // XX // Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. 
We consider the common_mask_len as a "reduction in the length of the k-sequence.", and adjust end_block_idx based on it, to save some calculation. // Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN); end_block_idx = batch_idx == sched_meta.end_req_idx ? min(sched_meta.end_block_idx, last_block_in_seq) : last_block_in_seq; CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE); } } else { rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k; } // Define global tensors using InputT = typename T::InputT; InputT* o_ptr = (InputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride; // (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1) float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride; // (/) : (1) Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( Shape, Int>{}, make_stride(params.o_row_stride, _1{}) )); Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout< Shape>, Stride<_1> >{}); // Copy K0 and K1 
launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x); if (start_block_idx+1 < end_block_idx) { launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); } Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape, Int>{}); // ((2, 2, 32), 1, 1) float rL[2]; rL[0] = rL[1] = 0.0f; // Clear buffers cute::fill(rO, 0.); if (threadIdx.x < size(sM)) { sM[threadIdx.x] = MAX_INIT_VAL_SM; } // Wait for Q barrier_Q->wait(cur_phase_Q); cur_phase_Q ^= 1; Tensor rQ8 = make_tensor(Shape, _1, _4>{}); retrieve_rP_from_sP(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup); if (warpgroup_idx == 0) { // Warpgroup 0 Tensor rP0 = make_tensor((typename T::rP0Layout){}); // NOTE We don't use the pipelined version of Q K^T here since it leads // to a slow-down (or even register spilling, thanks to the great NVCC) // Wait for K0 CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 9; ++i) { if (idx_in_warpgroup == 0) barriers_K0[i].arrive_and_expect_tx(64*64*2); barriers_K0[i].wait(cur_phase_K0); } cur_phase_K0 ^= 1; // Issue P0 = Q @ K0^T, wait if (start_block_idx-16777216 < end_block_idx) { // NOTE We use this `if` to prevent register spilling warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); } // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0 NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized); cute::warpgroup_wait<0>(); #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ wg0_subroutine( \ tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ rQ8, rP0, rO, rL, rRightBorderForQSeq, \ barriers_K0, barriers_K1, cur_phase_K0, \ tma_params, params, \ block_table_ptr, seqlen_k, 
block_idx, end_block_idx, idx_in_warpgroup \ ); int block_idx = start_block_idx; #pragma unroll 1 for (; block_idx < end_block_idx-2; block_idx += 2) { LAUNCH_WG0_SUBROUTINE(false, false); } if (block_idx+1 < end_block_idx) { LAUNCH_WG0_SUBROUTINE(false, true); } else if (block_idx < end_block_idx) { LAUNCH_WG0_SUBROUTINE(true, false); } } else { // Warpgroup 1 Tensor rP1 = make_tensor((typename T::rP0Layout){}); if (start_block_idx+1 < end_block_idx) { // Issue rP1 = sQ @ sK1, wait warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); cute::warpgroup_wait<0>(); } #define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \ wg1_subroutine( \ tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ rQ8, rP1, rO, rL, rRightBorderForQSeq, \ barriers_K0, barriers_K1, cur_phase_K1, \ tma_params, params, \ block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ ); int block_idx = start_block_idx; #pragma unroll 1 for (; block_idx < end_block_idx-3; block_idx += 2) { LAUNCH_WG1_SUBROUTINE(false, false, false); } if (block_idx+2 < end_block_idx) { LAUNCH_WG1_SUBROUTINE(false, false, true); block_idx += 2; LAUNCH_WG1_SUBROUTINE(true, false, false); } else if (block_idx+1 < end_block_idx) { LAUNCH_WG1_SUBROUTINE(false, true, false); } else if (block_idx < end_block_idx) { LAUNCH_WG1_SUBROUTINE(true, false, false); } } // Reduce rL across threads within the same warp rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); // Reduce rL across warpgroups int my_row = get_AorC_row_idx(0, idx_in_warpgroup); if (idx_in_warpgroup%4 == 0) { sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0]; sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1]; } __syncthreads(); if (warpgroup_idx == 0) { rL[0] += sL_reduction_wksp[my_row + 64]; rL[1] += sL_reduction_wksp[my_row + 8 + 64]; } 
else { if (idx_in_warpgroup%4 == 0) { sL_reduction_wksp[my_row] += rL[0]; sL_reduction_wksp[my_row + 8] += rL[1]; } __syncwarp(); rL[0] = sL_reduction_wksp[my_row]; rL[1] = sL_reduction_wksp[my_row+8]; } // Prune out when rL is 0.0f or NaN // rL may be 0.0f if there are large values (~10^12) in QK^T, which leads // to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error. // When this happens, we set rL to 1.0f. This aligns with the old version // of the MLA kernel. CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 2; ++i) rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i]; // Copy Q for the next batch if (batch_idx+1 <= sched_meta.end_req_idx) { launch_q_copy(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q); } else { // Allow the next kernel (the combine kernel) to launch // The next kernel MUST be the combine kernel cudaTriggerProgrammaticLaunchCompletion(); } int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M); if (is_no_split) { store_o(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); int i = threadIdx.x; if (i < num_valid_seq_q) { float cur_L = sL_reduction_wksp[i]; gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? 
INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E; } cute::tma_store_wait<0>(); } else { // Don't use __ldg because of PDL and instruction reordering int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx; float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< Shape, Int>, Stride, _1> >{}); Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout< Shape>, Stride<_1> >{}); store_o(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); int i = threadIdx.x; if (i < num_valid_seq_q) { float cur_L = sL_reduction_wksp[i]; gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? 
-INFINITY : log2f(cur_L) + sM(i); } cute::tma_store_wait<0>(); } if (batch_idx != sched_meta.end_req_idx) __syncthreads(); } #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); } #endif } template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms) { FLASH_ASSERT(params.d == Config::HEAD_DIM_K); FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V); using T = Traits; auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); auto tma_Q = cute::make_tma_copy( SM90_TMA_LOAD{}, make_tensor( make_gmem_ptr((InputT*)params.q_ptr), make_layout( shape_Q, make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride) ) ), tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} ) ); auto shape_K = make_shape(Int{}, Int{}, params.h_k, params.num_blocks); auto tma_K = cute::make_tma_copy( SM90_TMA_LOAD{}, make_tensor( make_gmem_ptr((InputT*)params.k_ptr), make_layout( shape_K, make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride) ) ), tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Layout< Shape, Int<64>>, Stride, _1> >{} ) ); auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b); auto tma_O = cute::make_tma_copy( SM90_TMA_STORE{}, make_tensor( make_gmem_ptr((InputT*)params.o_ptr), make_layout( shape_O, make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride) ) ), tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} ) ); TmaParams tma_params = { shape_Q, tma_Q, shape_K, tma_K, shape_O, tma_O }; auto mla_kernel = &flash_fwd_splitkv_mla_kernel; constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan); CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); cudaLaunchAttribute mla_kernel_attributes[1]; 
mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; cudaLaunchConfig_t mla_kernel_config = { dim3(num_m_block, params.h_k, params.num_sm_parts), dim3(T::NUM_THREADS, 1, 1), smem_size, params.stream, mla_kernel_attributes, 1 }; cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); CHECK_CUDA_KERNEL_LAUNCH(); } } ================================================ FILE: csrc/sm90/decode/dense/splitkv_mla.h ================================================ #pragma once #include "params.h" namespace sm90 { template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); } ================================================ FILE: csrc/sm90/decode/dense/traits.h ================================================ #pragma once #include #include #include #include #include "config.h" using TMABarrier = cutlass::arch::ClusterTransactionBarrier; using namespace cute; template struct Traits { using InputT = InputT_; static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; static constexpr int NUM_THREADS = 256; static_assert(std::is_same_v || std::is_same_v); using TiledMMA_QK_sQ = decltype(make_tiled_mma( GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), Layout>{} )); using TiledMMA_QK_rQ = decltype(make_tiled_mma( GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), Layout>{} )); using TiledMMA_PV_LocalP = decltype(make_tiled_mma( GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), Layout>{} )); using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), Layout>{} )); using SmemLayoutQ = decltype(tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); using SmemLayoutK = 
decltype(tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); using SmemLayoutV = decltype(composition( SmemLayoutK{}, make_layout(Shape, Int>{}, GenRowMajor{}) )); // A transposed version of SmemLayoutK using SmemLayoutP0 = decltype(tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); using rP0Layout = decltype(layout(partition_fragment_C( TiledMMA_QK_sQ{}, Shape, Int>{} ))); struct SharedMemoryPlan { cute::array_aligned> smem_sQ; cute::array_aligned> smem_sK0; cute::array_aligned> smem_sK1; cute::array_aligned> smem_sP0; cute::array_aligned smem_sM; cute::array_aligned sL_reduction_wksp; cute::array_aligned smem_sScale0; cute::array_aligned smem_sScale1; TMABarrier barriers_K0[HEAD_DIM_K/64]; TMABarrier barriers_K1[HEAD_DIM_K/64]; TMABarrier barrier_Q; }; }; template< typename ShapeQ, typename TMA_Q, typename ShapeK, typename TMA_K, typename ShapeO, typename TMA_O > struct TmaParams { ShapeQ shape_Q; TMA_Q tma_Q; ShapeK shape_K; TMA_K tma_K; ShapeO shape_O; TMA_O tma_O; }; enum NamedBarriers : int { sScale0Ready = 0, sScale1Ready = 1, sP0Ready = 2, rO1sP0sV0RIssued = 3, sMInitialized = 4, }; ================================================ FILE: csrc/sm90/decode/sparse_fp8/components/config.h ================================================ #pragma once #include #include #include #include "defines.h" using namespace cute; namespace sm90::decode::sparse_fp8 { static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_V = 512; static constexpr int HEAD_DIM_NOPE = HEAD_DIM_V; static constexpr int HEAD_DIM_ROPE = HEAD_DIM_K - HEAD_DIM_V; static constexpr int QUANT_TILE_SIZE = 128; static constexpr int NUM_SCALES = HEAD_DIM_NOPE / QUANT_TILE_SIZE; static constexpr int NUM_BYTES_PER_TOKEN = HEAD_DIM_NOPE + NUM_SCALES*sizeof(float) + HEAD_DIM_ROPE*sizeof(bf16); static constexpr int PAGE_BLOCK_SIZE = 64; } ================================================ FILE: csrc/sm90/decode/sparse_fp8/components/dequant.h 
================================================
#pragma once

#include
#include

#include "defines.h"

namespace sm90::decode::sparse_fp8 {

// Eight FP8 (E4M3) values packed as two 4-wide vectors: `lo` holds elements 0-3, `hi` holds 4-7.
struct fp8x8 {
    __nv_fp8x4_e4m3 lo;
    __nv_fp8x4_e4m3 hi;
};

// Sixteen FP8 (E4M3) values packed as two fp8x8 halves.
struct fp8x16 {
    fp8x8 lo;
    fp8x8 hi;
};

// Dequantize 8 FP8 values into 8 bfloat16 values and apply a scale.
// Each 4-wide FP8 group is widened to float4, rounded to two bf16 pairs with
// __float22bfloat162_rn, and only then multiplied by `scale_bf162`, so the
// scale is applied to already-rounded bf16 values (bf16-precision multiply).
// Output order: inputs.lo -> result.a01 / result.a23, inputs.hi -> result.a45 / result.a67.
// NOTE(review): DEQUANT_FP8x4 is never #undef'd, so it stays defined for any
// code that includes this header afterwards.
__device__ __forceinline__ bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) {
    #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
    { \
        float4 fp32x4 = (float4)(FP8x4); \
        OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \
        OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \
    }
    bf16x8 result;
    DEQUANT_FP8x4(result.a01, result.a23, inputs.lo);
    DEQUANT_FP8x4(result.a45, result.a67, inputs.hi);
    return result;
}

// L1 cache eviction policy for the global loads below. Each value selects
// the matching PTX `.L1::<policy>` qualifier (no_allocate / evict_first /
// evict_normal / evict_last).
enum class L1CacheHint {
    NO_ALLOCATE,
    EVICT_FIRST,
    EVICT_NORMAL,
    EVICT_LAST
};

// L2 prefetch size for the global loads below. It selects the PTX
// `.L2::64B`, `.L2::128B` or `.L2::256B` qualifier.
enum class L2PrefetchHint {
    B64,
    B128,
    B256
};

// Load 16 bytes from global memory with a single
// `ld.global.nc.L1::<l1>.L2::<l2>.v4.s32` instruction and reinterpret them as T.
// T must be exactly 16 bytes (static_assert below).
// The L1 policy and L2 prefetch size are compile-time template parameters.
// The EXEC / DISPATCH_L2 macros splice them into the PTX string so that
// exactly one asm statement survives `if constexpr`. Both macros are
// #undef'd before returning.
// `.nc` is the non-coherent (read-only) load path. `addr` must therefore
// point to data that is not written while the kernel runs.
// NOTE(review): `addr` presumably has to be 16-byte aligned, as for any
// v4.s32 vector load; confirm at call sites.
template<
    typename T,
    L1CacheHint l1_cache_hint,
    L2PrefetchHint l2_prefetch_hint
>
__device__ __forceinline__ T load_128b_from_gmem(const void* addr) {
    static_assert(sizeof(T) == 128/8);
    int4 ret;
    #define EXEC(L1_HINT_STR, L2_HINT_STR) { \
        asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v4.s32 {%0, %1, %2, %3}, [%4];" \
            : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) \
            : "l"(addr)); \
    }
    #define DISPATCH_L2(L1_HINT_STR) { \
        if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \
            EXEC(L1_HINT_STR, "64B") \
        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \
            EXEC(L1_HINT_STR, "128B") \
        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \
            EXEC(L1_HINT_STR, "256B") \
    }
    if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)
        DISPATCH_L2("no_allocate")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)
        DISPATCH_L2("evict_first")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)
        DISPATCH_L2("evict_normal")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)
        DISPATCH_L2("evict_last")
    #undef EXEC
    #undef DISPATCH_L2
    return *reinterpret_cast(&ret);
}

// 8-byte variant of load_128b_from_gmem. It issues a single
// `ld.global.nc...v2.s32` and reinterprets the result as T.
// T must be exactly 8 bytes. The same read-only (`.nc`) caveat applies.
// NOTE(review): `addr` presumably has to be 8-byte aligned; confirm at call sites.
template<
    typename T,
    L1CacheHint l1_cache_hint,
    L2PrefetchHint l2_prefetch_hint
>
__device__ __forceinline__ T load_64b_from_gmem(const void* addr) {
    static_assert(sizeof(T) == 64/8);
    int2 ret;
    #define EXEC(L1_HINT_STR, L2_HINT_STR) { \
        asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v2.s32 {%0, %1}, [%2];" \
            : "=r"(ret.x), "=r"(ret.y) \
            : "l"(addr)); \
    }
    #define DISPATCH_L2(L1_HINT_STR) { \
        if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \
            EXEC(L1_HINT_STR, "64B") \
        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \
            EXEC(L1_HINT_STR, "128B") \
        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \
            EXEC(L1_HINT_STR, "256B") \
    }
    if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)
        DISPATCH_L2("no_allocate")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)
        DISPATCH_L2("evict_first")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)
        DISPATCH_L2("evict_normal")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)
        DISPATCH_L2("evict_last")
    #undef EXEC
    #undef DISPATCH_L2
    return *reinterpret_cast(&ret);
}

}

================================================
FILE: csrc/sm90/decode/sparse_fp8/components/helpers.h
================================================
#pragma once

#include
#include

#include "config.h"

using namespace cute;

namespace sm90::decode::sparse_fp8 {

// In the WGMMA layouts of fragment A and fragment C, the data each thread holds lies in two particular rows. This function converts local_row_idx (0 or 1) to the actual row index.
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
// Concretely, from the formula below:
//   - each warp of the warpgroup owns 16 consecutive rows;
//   - the 4 lanes of a quad share a row;
//   - local_row_idx == 1 is 8 rows below local_row_idx == 0.
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
    return row_idx;
}

// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// Issue a warpgroup MMA tCrC (+)= tCrA @ tCrB. The K mode (mode 2 of tCrA/tCrB)
// is iterated manually, one cute::gemm call per K block.
// Compile-time switches used in the body:
//   arrive    - call warpgroup_arrive() before issuing the MMAs;
//   zero_init - the first K block overwrites tCrC (ScaleOut::Zero) and later
//               blocks accumulate (ScaleOut::One);
//   commit    - call warpgroup_commit_batch() after issuing;
//   wg_wait   - if >= 0, wait for outstanding WGMMA batches (warpgroup_wait)
//               before returning.
// Without zero_init, the first K block uses whatever tiled_mma.accumulate_
// holds on entry. After at least one K block, accumulate_ is left as ScaleOut::One.
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    // Is_RS: presumably true when the A operand is register-sourced (RS WGMMA) rather
    // than a shared-memory descriptor. Only then is tCrA fenced below.
    constexpr bool Is_RS = !cute::is_base_of::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); }
    // Fence the accumulator so register accesses are not reordered across the async WGMMA.
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); }
}

// Issue a TMA copy from `src` to `dst`. The whole copy is treated as a
// single slice (get_slice(_0{})).
// Completion is signalled on mbarrier `bar`. `multicast_mask` (default 0) and
// `cache_hint` (default EVICT_NORMAL) are forwarded to tma_copy.with().
// This helper does not arm `bar`: the caller must set the expected transaction
// bytes on it (e.g. arrive_and_expect_tx).
// NOTE(review): presumably meant to be called by a single elected thread; confirm.
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    const Tensor0 &src,
    Tensor1 &dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL,
    const uint16_t &multicast_mask = 0
) {
    auto thr_tma = tma_copy.get_slice(_0{});
    cute::copy(
        tma_copy.with(reinterpret_cast(bar), multicast_mask, cache_hint),
        thr_tma.partition_S(src),
        thr_tma.partition_D(dst)
    );
}

// Asynchronously store 16 bytes of `data` to the `.shared::cluster` address
// `dst_ptr`, which may belong to another CTA of the cluster, using PTX `st.async`.
// The store's byte count is reported to mbarrier `mbar_ptr` via `complete_tx::bytes`.
// `data` is read as a long2, so T is expected to be 16 bytes.
template
CUTE_DEVICE
static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
    long2 data_long2 = *reinterpret_cast(&data);
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
    asm volatile (
        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
        :
        : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
    );
}

// Bulk-copy `size` bytes from this CTA's shared memory (`src_ptr`) to the
// `.shared::cluster` address `dst_ptr`, which may be in a peer CTA.
// Completion is reported as transaction bytes on mbarrier `mbar_ptr`.
// Per PTX, `size` must be a multiple of 16 and both addresses 16-byte aligned.
CUTE_DEVICE
static void cp_async_bulk_shared_cta_shared_cluster(void* dst_ptr, void* src_ptr, int size, transac_bar_t* mbar_ptr) {
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
    uint32_t src_addr = cute::cast_smem_ptr_to_uint(src_ptr);
    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
    asm volatile (
        "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
        :
        : "r"(dst_addr), "r"(src_addr), "r"(size), "r"(mbar_addr)
    );
}

// 16777216 == 1 << 24.
static constexpr int PEER_ADDR_MASK = 16777216;  // peer_addr = my_addr ^ PEER_ADDR_MASK.

// Return the address of the same object in the peer CTA by flipping bit 24 of `p`.
// NOTE(review): this relies on the two CTAs of a 2-CTA cluster having addresses
// that differ only in bit 24. That is an assumption about the address mapping;
// confirm before reusing it for other cluster shapes.
template
CUTE_DEVICE
T* get_peer_addr(T* p) {
    return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}

}

================================================
FILE: csrc/sm90/decode/sparse_fp8/config.h
================================================
#pragma once

#include
#include
#include
#include

#include "defines.h"
#include "params.h"

using namespace cute;

namespace sm90::decode::sparse_fp8 {

// Compile-time configuration and device-side helpers for the SM90 sparse
// FP8-KV decode kernel, specialized on NUM_HEADS (64 or 128) and MODEL_TYPE
// (ModelType::V32 vs. the other model; the instantiation file names suggest MODEL1).
template
class KernelTemplate {
public:
    static_assert(NUM_HEADS == 64 || NUM_HEADS == 128);
    // Number of 64-head blocks. One CTA per head block, and those CTAs form one cluster.
    static constexpr int NUM_M_BLOCKS = NUM_HEADS / 64;
    static constexpr int CLUSTER_SIZE = NUM_M_BLOCKS;

    // Head dimensions. HEAD_DIM_NOPE is 512 for V32 and 448 otherwise.
    static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512;
    static constexpr int HEAD_DIM_V = 512;
    static constexpr int HEAD_DIM_ROPE = 64;
    static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE;
    // Quantization tiling of the NoPE part: V32 uses 512/128 = 4 scales per token.
    static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
    static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8;  // For MODEL1: 7 fp8_e4m3 + 1 padding

    // Three 128-thread warpgroups per CTA.
    static constexpr int NUM_THREADS = 128*3;
    // Rows (heads) per CTA, and KV tokens per top-k block.
    static constexpr int BLOCK_M = 64;
    static constexpr int TOPK_BLOCK_SIZE = 64;
    // Number of K buffers in shared memory (double buffering).
    static constexpr int NUM_K_BUFS = 2;

    using SmemLayoutQTile = decltype(tile_to_shape(
        GMMA::Layout_SW128_Atom{},
        Shape, Int<64>>{}
    ));
    template
    using SmemLayoutQTiles = decltype(tile_to_shape(
        SmemLayoutQTile{},
        Shape, Int<64*NUM_TILES>>{},
        Step<_1, _2>{}
    ));
    using SmemLayoutQ = SmemLayoutQTiles;

    using SmemLayoutKTile = decltype(tile_to_shape(
        GMMA::Layout_INTER_Atom{},
        Shape, _64>{},
        Step<_1, _2>{}
    ));
    template
    using SmemLayoutKTiles = decltype(tile_to_shape(
        SmemLayoutKTile{},
        Shape, Int<64*NUM_TILES>>{},
        Step<_1, _2>{}
    ));
    // V is read from the K buffer through a transposed view (composition over SmemLayoutKTiles).
    template
    using SmemLayoutKTilesTransposed = decltype(composition(
        SmemLayoutKTiles{},
        Layout, Int>, Stride, _1>>{}
    ));

    // Swizzle width (in elements) of the bf16 output staging buffer. It
    // determines how many STSM base pointers store_o needs (OBUF_SW/16).
    static constexpr int OBUF_SW = 64;
    using SmemLayoutOBufAtom = GMMA::Layout_K_SW128_Atom;
    using SmemLayoutOBuf = decltype(tile_to_shape(
        SmemLayoutOBufAtom{},
        Shape, Int>{},
        Step<_1, _2>{}
    ));

    using SmemLayoutOAccumBuf = Layout<
        Shape, Int>,
        Stride, _1>  // We use stride = 520 here to avoid bank conflict
    >;

    using SmemLayoutK = SmemLayoutKTiles;
    using SmemLayoutV = SmemLayoutKTilesTransposed;
    using SmemLayoutHalfV = SmemLayoutKTilesTransposed;
    using SmemLayoutS = decltype(tile_to_shape(
        GMMA::Layout_K_SW128_Atom{},
        Shape, Int>{}
    ));

    // Shared-memory plan of one CTA.
    struct SharedMemoryPlan {
        array_aligned> q;
        // The K buffers overlay the output staging buffers (bf16 oBuf / fp32 oAccumBuf).
        // They must not be live at the same time.
        union {
            array_aligned> k[NUM_K_BUFS];
            array_aligned> oBuf;
            array_aligned> oAccumBuf;
        } u;
        CUTE_ALIGNAS(1024) array_aligned> s;
        // Per K buffer, per token: whether that KV token is valid (invalid tokens get masked).
        bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
        // Per-row (per-head) softmax max/sum, rescale factors and output scales.
        float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
        // mbarriers: Q arrival, K ready from the local / remote CTA, and K buffer free.
        transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
    };

    // Parameters passed to the kernel for TMA: the Q shape and TMA atom, plus
    // the raw tensor map used for the 5-D TMA store of O.
    template<
        typename Shape_Q,
        typename TMA_Q
    >
    struct TmaParams {
        Shape_Q shape_Q;
        TMA_Q tma_Q;
        CUtensorMap tensor_map_o;
    };

    // QK^T: 64x64x16 WGMMA with both operands in smem (SS) or Q in registers (RS).
    using TiledMMA_QK = decltype(make_tiled_mma(
        GMMA::MMA_64x64x16_F32BF16BF16_SS{},
        Layout>{}
    ));
    using TiledMMA_QK_rQ = decltype(make_tiled_mma(
        GMMA::MMA_64x64x16_F32BF16BF16_RS{},
        Layout>{}
    ));
    // PV: 64x256x16 WGMMA. "LocalP" takes P from registers (RS) and "RemoteP" reads P from smem (SS).
    using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
        GMMA::MMA_64x256x16_F32BF16BF16_RS{},
        Layout>{}
    ));
    using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
        GMMA::MMA_64x256x16_F32BF16BF16_SS{},
        Layout>{}
    ));

    // Named-barrier IDs used with cutlass::arch::NamedBarrier.
    enum NamedBarriers : uint32_t {
        sScale_and_sS_ready = 0,
        sScale_and_sS_free = 1,
        oBuf_free_and_sL_ready = 2,
        epilogue_r2s_ready = 3,
        batch_loop_sync = 4,
        warpgroup0_sync = 5
    };

    // Synchronize all threads within the cluster (which processes one q token).
    // With a single-CTA cluster this is just __syncthreads(). Otherwise it uses a
    // cluster barrier: arrive (relaxed), then wait (acquire).
    static __forceinline__ __device__ void sync_all_threads_in_cluster() {
        if constexpr (CLUSTER_SIZE == 1) {
            __syncthreads();
        } else {
            ku::barrier_cluster_arrive_relaxed();
            ku::barrier_cluster_wait_acquire();
        }
    }

    // Save rPb (64x64, bfloat16) to sP using the stmatrix instruction.
    // The register->smem tiled copy is built from the accumulator (C) layout of
    // TiledMMA_QK, and each thread copies its own fragment.
    template<
        typename Tensor0,
        typename Tensor1
    >
    static __forceinline__ __device__ void save_rPb_to_sP(
        Tensor0 const &rPb,
        Tensor1 const &sP,
        int idx_in_warpgroup
    ) {
        auto r2s_copy = make_tiled_copy_C(
            Copy_Atom{},
            TiledMMA_QK{}
        );
        ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
        Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
        Tensor thr_copy_sP = thr_copy.partition_D(sP);
        cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
    }

    // Epilogue: scale this warpgroup's output accumulator rO by o_scales and
    // write it out.
    //   IS_NO_SPLIT == true : convert to bf16, stage in sOutputBuf with STSM, then
    //       thread 0 issues one 5-D TMA store of the whole plan.u.oBuf using
    //       tma_params.tensor_map_o. gOorAccum and num_valid_seq_q are not used on
    //       this path; any clipping is up to the tensor map's bounds.
    //   IS_NO_SPLIT == false: write fp32 values into sOutputAccumBuf, then one
    //       elected lane per warp bulk-copies rows [0, num_valid_seq_q) to gOorAccum.
    // o_scales[0] / o_scales[1] scale the thread's first / second accumulator row
    // (see get_AorC_row_idx).
    // Both paths synchronize 256 threads (two warpgroups) on epilogue_r2s_ready.
    // Both end with tma_store_arrive() and do NOT wait for the store to complete.
    template<
        bool IS_NO_SPLIT,
        typename TMAParams,
        typename Tensor0,
        typename Tensor1,
        typename Tensor2,
        typename Tensor3
    >
    static __forceinline__ __device__ void store_o(
        Tensor0 &rO,  // ((2, 2, 32), 1, 1)
        Tensor1 &gOorAccum,  // (BLOCK_M, HEAD_DIM_V)
        Tensor2 &sOutputBuf,
        Tensor3 &sOutputAccumBuf,
        SharedMemoryPlan &plan,
        float o_scales[2],
        TMAParams &tma_params,
        int batch_idx,
        int s_q_idx,
        int head_block_idx,
        int num_valid_seq_q,
        int warpgroup_idx,
        int idx_in_warpgroup
    ) {
        using cutlass::arch::NamedBarrier;
        if constexpr (IS_NO_SPLIT) {
            // Should convert the output to bfloat16 / float16, and save it to O
            // Here we don't pipeline STSM and tma store because it's slower
            // Each warpgroup fills its own 64x256 column half of the buffer.
            Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));

            // Calculate "base" ptrs in advance
            // Each STSM fills a chunk of shape 16x16. Since the buffer uses a SW-OBUF_SW swizzle, we need OBUF_SW/16 base pointers
            constexpr int NUM_CHUNKS_IN_SW_ATOM = OBUF_SW/16;
            bf16* base_output_buf_ptrs[NUM_CHUNKS_IN_SW_ATOM];
            CUTE_UNROLL
            for (int i = 0; i < NUM_CHUNKS_IN_SW_ATOM; ++i) {
                base_output_buf_ptrs[i] = &sMyOutputBuf((idx_in_warpgroup/32)*16+idx_in_warpgroup%16, idx_in_warpgroup%32/16*8 + i*16);
            }

            CUTE_UNROLL
            for (int idx = 0; idx < (HEAD_DIM_V/2)/16; idx += 1) {
                // In each iteration we deal with a chunk of shape 16x16
                // Elements idx*8+{0,1,4,5} belong to the first row (o_scales[0]) and
                // idx*8+{2,3,6,7} to the second row (o_scales[1]).
                using bf16x2 = __nv_bfloat162;
                bf16x2 a01 = __float22bfloat162_rn(float2{rO(idx*8+0)*o_scales[0], rO(idx*8+1)*o_scales[0]});
                bf16x2 a23 = __float22bfloat162_rn(float2{rO(idx*8+2)*o_scales[1], rO(idx*8+3)*o_scales[1]});
                bf16x2 a45 = __float22bfloat162_rn(float2{rO(idx*8+4)*o_scales[0], rO(idx*8+5)*o_scales[0]});
                bf16x2 a67 = __float22bfloat162_rn(float2{rO(idx*8+6)*o_scales[1], rO(idx*8+7)*o_scales[1]});
                // idx%4 picks the base pointer inside the current 64-column group (the
                // literal 4 equals NUM_CHUNKS_IN_SW_ATOM for OBUF_SW == 64).
                // (idx/4*4)*16*64 == (idx/4)*4096 elements advances to the next 64-column group.
                // NOTE(review): this assumes each 64-column group is a contiguous
                // 64x64 tile in SmemLayoutOBuf; confirm if OBUF_SW or BLOCK_M changes.
                SM90_U32x4_STSM_N::copy(
                    *reinterpret_cast(&a01),
                    *reinterpret_cast(&a23),
                    *reinterpret_cast(&a45),
                    *reinterpret_cast(&a67),
                    *reinterpret_cast(base_output_buf_ptrs[idx%4] + (idx/4*4)*16*64)
                );
            }
            // Make the generic-proxy smem writes visible to the TMA (async proxy), then
            // wait until both warpgroups have finished staging.
            cutlass::arch::fence_view_async_shared();
            NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
            if (threadIdx.x == 0) {
                // Coordinates (0, head_block_idx*64, 0, s_q_idx, batch_idx). Their meaning
                // depends on how tensor_map_o is encoded on the host.
                SM90_TMA_STORE_5D::copy(
                    &tma_params.tensor_map_o,
                    plan.u.oBuf.data(),
                    0,
                    head_block_idx*64,
                    0,
                    s_q_idx,
                    batch_idx
                );
                cute::tma_store_arrive();
            }
        } else {
            // Should save the result to OAccum
            // Each iteration writes two adjacent fp32 values. The row follows the
            // accumulator layout (16 rows per warp, +8 for the second row). The column
            // offset is warpgroup_idx*256 within the 512-wide row.
            CUTLASS_PRAGMA_UNROLL
            for (int idx = 0; idx < size(rO); idx += 2) {
                int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
                int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
                *(float2*)(&(sOutputAccumBuf(row, col))) = float2 {
                    rO(idx) * o_scales[idx%4>=2],
                    rO(idx+1) * o_scales[idx%4>=2],
                };
            }
            cutlass::arch::fence_view_async_shared();
            NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
            if (elect_one_sync()) {
                // One lane per warp. With 256 participating threads (8 warps), warp w copies
                // rows w, w+8, ...; each row is HEAD_DIM_V fp32 values.
                CUTLASS_PRAGMA_UNROLL
                for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) {
                    int row = local_row * (256/32) + (threadIdx.x / 32);
                    if (row < num_valid_seq_q) {
                        SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float));
                    }
                }
                cute::tma_store_arrive();
            }
        }
    }

    // Device-side kernel body (defined out of line).
    template
    static __device__ __forceinline__ void devfunc(const SparseAttnDecodeParams ¶ms, const TMAParams &tma_params);

    // Host-side launcher (defined out of line).
    static void run(const SparseAttnDecodeParams ¶ms);
};

}

================================================
FILE: csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu
================================================
#include "../splitkv_mla.cuh"

namespace sm90::decode::sparse_fp8 {

// Explicit instantiation of the launcher for the configuration named by this file (MODEL1, 128 heads).
template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms);

}

================================================
FILE: csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu
================================================
#include "../splitkv_mla.cuh"

namespace sm90::decode::sparse_fp8 {

// Explicit instantiation of the launcher for the configuration named by this file (MODEL1, 64 heads).
template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms);

}

================================================
FILE: csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu
================================================
#include "../splitkv_mla.cuh"

namespace sm90::decode::sparse_fp8 {

// Explicit instantiation of the launcher for the configuration named by this file (V32, 128 heads).
template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms);

}

================================================
FILE: csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu
================================================
#include
"../splitkv_mla.cuh" namespace sm90::decode::sparse_fp8 { template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); } ================================================ FILE: csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh ================================================ #pragma once #include "splitkv_mla.h" #include #include #include #include #include #include #include #include "utils.h" #include "components/dequant.h" #include "components/helpers.h" #include "config.h" using namespace cute; namespace sm90::decode::sparse_fp8 { static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan using cutlass::arch::fence_view_async_shared; using cutlass::arch::NamedBarrier; using fp8_e8m0 = __nv_fp8_e8m0; template< typename Tensor0, typename Tensor1, typename Tensor2 > __forceinline__ __device__ void scale_softmax( Tensor0 &rP, Tensor1 &rS, Tensor2 &rO, float scale_softmax_log2, float sScale[], float rM[2], float rL[2], bool is_kv_valid[], int block_idx, int idx_in_warpgroup ) { float scale_for_olds[2]; CUTE_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, _), _, _)); Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _)); Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); float cur_max = -INFINITY; CUTE_UNROLL for (int i = 0; i < size(cur_rP); ++i) { if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2]) cur_rP(i) = -INFINITY; cur_max = max(cur_max, cur_rP(i)); } cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); cur_max *= scale_softmax_log2; float old_max = rM[local_row_idx]; rM[local_row_idx] = max(cur_max, old_max); float scale_for_old = exp2f(old_max - rM[local_row_idx]); scale_for_olds[local_row_idx] = scale_for_old; CUTE_UNROLL for (int i = 0; i < size(cur_rO); ++i) { cur_rO(i) *= scale_for_old; } float cur_sum = 0; CUTE_UNROLL for (int 
i = 0; i < size(cur_rP); ++i) { cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]); cur_rS(i) = (bf16)cur_rP(i); cur_sum += cur_rP(i); } rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; } if (idx_in_warpgroup%4 == 0) *(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds); } template template __device__ void KernelTemplate::devfunc(const SparseAttnDecodeParams ¶ms, const TMAParams &tma_params) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) const int head_block_idx = NUM_M_BLOCKS == 1 ? 0 : blockIdx.x; const int s_q_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int idx_in_cluster = CLUSTER_SIZE == 1 ? 0 : head_block_idx % 2; const int warpgroup_idx = cutlass::canonical_warp_group_idx(); const int idx_in_warpgroup = threadIdx.x % 128; const int warp_idx = cutlass::canonical_warp_idx_sync(); // Define shared tensors extern __shared__ char wksp_buf[]; SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{}); Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{}); Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); float* sM = plan.sM; float* sL = plan.sL; float* sScale = plan.sScale; // Prefetch TMA descriptors if (warp_idx == 0 && elect_one_sync()) { cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); cute::prefetch_tma_descriptor(&tma_params.tensor_map_o); } // Initialize TMA barriers if (warp_idx == 0 && elect_one_sync()) { plan.bar_q.init(1); if constexpr (CLUSTER_SIZE == 2) { CUTE_UNROLL for (int i = 0; i < NUM_K_BUFS; ++i) { plan.bar_k_local_ready[i].init(128); plan.bar_k_remote_ready[i].init(1); plan.bar_k_avail[i].init(4); } } else { CUTE_UNROLL for (int i = 0; i < NUM_K_BUFS; ++i) { 
plan.bar_k_local_ready[i].init(128); plan.bar_k_avail[i].init(256); } } cutlass::arch::fence_barrier_init(); } ku::barrier_cluster_arrive_relaxed(); int bar_phase_k = 0; // Don't use array here to prevent using local memory // Programmatic Dependent Launch: Wait for the previous kernel to finish // Don't use PDL because of compiler bugs! // cudaGridDependencySynchronize(); DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx]; if (sched_meta.begin_req_idx >= params.b) return; if (warp_idx == 0 && elect_one_sync()) { Tensor gQ = flat_divide( tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, sched_meta.begin_req_idx), Tile, Int>{} )(_, _, head_block_idx, _0{}); launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); } ku::barrier_cluster_wait_acquire(); struct MainloopArgs { int start_block_idx, end_block_idx; bool is_no_split; // The following fields are only valid for MODEL1 int topk_length, extra_topk_length, num_orig_kv_blocks; }; auto get_cur_req_info = [&](int batch_idx) -> MainloopArgs { MainloopArgs args; int total_topk_padded; if constexpr (MODEL_TYPE == ModelType::V32) { total_topk_padded = params.topk; } else { int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk; int orig_topk_padded = max(ku::ceil(topk_length, (int)TOPK_BLOCK_SIZE), (int)TOPK_BLOCK_SIZE); int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk; total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)TOPK_BLOCK_SIZE); args.topk_length = topk_length; args.extra_topk_length = extra_topk_length; args.num_orig_kv_blocks = orig_topk_padded / TOPK_BLOCK_SIZE; } args.start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; args.end_block_idx = batch_idx == sched_meta.end_req_idx ? 
sched_meta.end_block_idx : total_topk_padded / TOPK_BLOCK_SIZE; args.is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true); return args; }; if (warpgroup_idx == 0) { cutlass::arch::warpgroup_reg_alloc<192>(); TiledMMA tiled_mma_QK = TiledMMA_QK{}; ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup); TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{}; ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); float rL[2], rM[2]; Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); float rAttn_sink[2] = {-CUDART_INF_F, -CUDART_INF_F}; if (params.attn_sink != nullptr) { for (int i = 0; i < 2; ++i) { int head_idx = head_block_idx*BLOCK_M + get_AorC_row_idx(i, idx_in_warpgroup); rAttn_sink[i] = __ldg((float*)params.attn_sink + head_idx) * CUDART_L2E_F; } } #pragma unroll 1 for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { MainloopArgs args = get_cur_req_info(batch_idx); rL[0] = rL[1] = 0.0f; rM[0] = rM[1] = MAX_INIT_VAL; cute::fill(rO, 0.); // Wait for Q plan.bar_q.wait((sched_meta.begin_req_idx-batch_idx)&1); CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) { int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS; Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{}); // Wait, issue WGMMA plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); if constexpr (CLUSTER_SIZE == 2) { plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); } gemm( tiled_mma_QK, thr_mma_QK.partition_fragment_A(sQ), thr_mma_QK.partition_fragment_B(sK), rP ); bar_phase_k ^= 1<(); // 
Calculate S = softmax(mask(scale(P))) if (block_idx != args.start_block_idx) NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS is free // Since in our case TOPK_BLOCK_SIZE == BLOCK_M, so we only need to do OOB checking for the last 2 blocks scale_softmax(rP, rS, rO, params.sm_scale_div_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup); // Store S into shared, inform warpgroup 1 save_rPb_to_sP(rS, sS, idx_in_warpgroup); fence_view_async_shared(); // Issue O += S @ V gemm( tiled_mma_PV, rS, thr_mma_PV.partition_fragment_B(sV), rO ); NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready); cute::warpgroup_wait<0>(); if constexpr (CLUSTER_SIZE == 2) { plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); } else { plan.bar_k_avail[buf_idx].arrive(); } } // Copy the next q if (threadIdx.x/32 == 0 && elect_one_sync()) { if (batch_idx != sched_meta.end_req_idx) { Tensor gQ = flat_divide( tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1), Tile, Int>{} )(_, _, head_block_idx, _0{}); launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); } else { // This kernel is followed by the combine kernel, so we signal PDL here cudaTriggerProgrammaticLaunchCompletion(); } } // Synchronize L and M across warpgroups rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); if (idx_in_warpgroup%4 == 0) { CUTE_UNROLL for (int i = 0; i < 2; ++i) { int row = get_AorC_row_idx(i, idx_in_warpgroup); sL[row] = rL[i]; sM[row] = rM[i]; } } float o_scales[2]; CUTE_UNROLL for (int i = 0; i < 2; ++i) { if (args.is_no_split) { o_scales[i] = rL[i] == 0.0f ? 
0.0f : __fdividef(1.0f, rL[i] + exp2f(rAttn_sink[i] - rM[i])); } else { o_scales[i] = rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i]); } if (idx_in_warpgroup%4 == 0) { int row = get_AorC_row_idx(i, idx_in_warpgroup); plan.sOScale[row] = o_scales[i]; } } // This is a synchronization point for warpgroup 0/1. // Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free // Warpgroup 1 should wait wg 0 for sL to be ready NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); CUTE_UNROLL for (int i = 0; i < 2; ++i) rL[i] = rL[i] == 0.0f ? 1.0f : rL[i]; int start_head_idx = head_block_idx*BLOCK_M; int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M); if (args.is_no_split) { bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1) Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( Shape, Int>{}, make_stride(params.stride_o_h_q, _1{}) )); float* gSoftmaxLse = (float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + start_head_idx; // (BLOCK_M) : (1) store_o(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); int i = threadIdx.x; if (i < num_valid_seq_q) { float cur_L = sL[i]; gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E; } cute::tma_store_wait<0>(); } else { int n_split_idx = batch_idx == sched_meta.begin_req_idx ? 
sched_meta.begin_split_idx : 0; int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1) float* gSoftmaxLseAccum = (float*)params.lse_accum + split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + start_head_idx; // (BLOCK_M) : (1) Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout( Shape, Int>{}, make_stride(params.stride_o_accum_h_q, _1{}) )); store_o(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); int i = threadIdx.x; if (i < num_valid_seq_q) { float cur_L = sL[i]; gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i]; } cute::tma_store_wait<0>(); } sync_all_threads_in_cluster(); } } else if (warpgroup_idx == 1) { cutlass::arch::warpgroup_reg_dealloc<160>(); TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{}; ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); Tensor rO = partition_fragment_C(tiled_mma_PV, Shape, Int>{}); #pragma unroll 1 for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { MainloopArgs args = get_cur_req_info(batch_idx); cute::fill(rO, 0.); CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) { int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS; Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{}); // Wait for S and sScale NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready); // Scale O float cur_scales[2]; *(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2); CUTE_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { Tensor cur_rO = 
flatten(rO(make_coord(_, local_row_idx, _), _, _)); CUTE_UNROLL for (int i = 0; i < size(cur_rO); ++i) { cur_rO(i) *= cur_scales[local_row_idx]; } } // Issue O += S @ V, and wait gemm( tiled_mma_PV, thr_mma_PV.partition_fragment_A(sS), thr_mma_PV.partition_fragment_B(sV), rO ); cute::warpgroup_wait<0>(); if constexpr (CLUSTER_SIZE == 2) { plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); } else { plan.bar_k_avail[buf_idx].arrive(); } if (block_idx != args.end_block_idx-1) NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available } NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); float o_scales[2]; CUTE_UNROLL for (int i = 0; i < 2; ++i) { int row = get_AorC_row_idx(i, idx_in_warpgroup); o_scales[i] = plan.sOScale[row]; } int start_head_idx = head_block_idx*BLOCK_M; int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M); if (args.is_no_split) { bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1) Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( Shape, Int>{}, make_stride(params.stride_o_h_q, _1{}) )); store_o(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); cute::tma_store_wait<0>(); } else { int n_split_idx = batch_idx == sched_meta.begin_req_idx ? 
sched_meta.begin_split_idx : 0; int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1) Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout( Shape, Int>{}, make_stride(params.stride_o_accum_h_q, _1{}) )); store_o(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); cute::tma_store_wait<0>(); } sync_all_threads_in_cluster(); } } else { // Producer warpgroup cutlass::arch::warpgroup_reg_dealloc<152>(); static_assert(CLUSTER_SIZE == 1 || CLUSTER_SIZE == 2); static constexpr int NUM_TOKENS_PER_THREAD = CLUSTER_SIZE == 1 ? 2 : 1; static constexpr int NUM_TOKENS_PER_ROUND = 32; // If head is 128, each CTA is responsible for dequantizing 32 tokens (1 rounds); if head is 64, each CTA is responsible for dequantizing 64 tokens (2 rounds) int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0); int lane_idx = idx_in_warpgroup % 32; int my_token_idx_base = warp_idx*8 + lane_idx%8; CUTE_NO_UNROLL for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { MainloopArgs args = get_cur_req_info(batch_idx); int* gIndices = params.indices + batch_idx*params.stride_indices_b + s_q_idx*params.stride_indices_s_q; // (topk) : (1) int* gExtraIndices = params.extra_indices + batch_idx*params.stride_extra_indices_b + s_q_idx*params.stride_extra_indices_s_q; // (extra_topk) : (1) int nxt_token_indexs[NUM_TOKENS_PER_THREAD]; CUTE_UNROLL for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) { if (MODEL_TYPE == ModelType::V32 || args.start_block_idx < args.num_orig_kv_blocks) nxt_token_indexs[round] = __ldg(gIndices + args.start_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + 
round*NUM_TOKENS_PER_ROUND + my_token_idx_base); } struct IsOrigBlock {}; struct IsExtraBlock {}; struct IsFirstExtraBlock {}; struct IsNotFirstExtraBlock {}; auto process_one_block = [&](int block_idx, auto is_extra_block_t, auto is_first_extra_block_t) { static constexpr bool IS_EXTRA_BLOCK = std::is_same_v; static constexpr bool IS_FIRST_EXTRA_BLOCK = std::is_same_v; int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS; int* indices_base; int page_block_size; int64_t k_block_stride, k_row_stride; fp8* k_ptr; if constexpr (!IS_EXTRA_BLOCK) { indices_base = gIndices + (block_idx)*TOPK_BLOCK_SIZE; page_block_size = params.page_block_size; k_block_stride = params.stride_kv_block; k_row_stride = params.stride_kv_row; k_ptr = (fp8*)params.kv; } else { indices_base = gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE; page_block_size = params.extra_page_block_size; k_block_stride = params.stride_extra_kv_block; k_row_stride = params.stride_extra_kv_row; k_ptr = (fp8*)params.extra_kv; } [[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length; [[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx; transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx])); CUTE_UNROLL for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) { int my_token_idx = my_token_idx_base + round*NUM_TOKENS_PER_ROUND; bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE; bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base); // Get prefetched token index int token_index; if constexpr (!IS_EXTRA_BLOCK) { token_index = nxt_token_indexs[round]; if (block_idx+1 != (MODEL_TYPE == ModelType::V32 ? 
args.end_block_idx : args.num_orig_kv_blocks)) nxt_token_indexs[round] = __ldg(gIndices + (block_idx+1)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx); } else { if constexpr (IS_FIRST_EXTRA_BLOCK) { token_index = __ldg(gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx); } else { token_index = nxt_token_indexs[round]; } if (block_idx+1 != args.end_block_idx) nxt_token_indexs[round] = __ldg(gExtraIndices + (block_idx+1-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx); } if constexpr (MODEL_TYPE == ModelType::MODEL1) { // For MODEL1, we need to check whether the token_index is within topk_length if (rel_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx >= topk_length) { token_index = -1; // To prevent IMA when we have invalid (e.g. INT_MAX) topk indexes outside topk_length } } int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error fp8* gK_base; bf16 scales[NUM_SCALES]; if constexpr (MODEL_TYPE == ModelType::V32) { static_assert(NUM_SCALES == 4); gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*k_row_stride; float scales_float[NUM_SCALES]; *(float4*)(scales_float) = load_128b_from_gmem((float*)(gK_base+HEAD_DIM_NOPE)); CUTE_UNROLL for (int i = 0; i < NUM_SCALES; ++i) { scales[i] = (bf16)scales_float[i]; } } else { static_assert(NUM_SCALES == 8); gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*sizeof(bf16)); fp8_e8m0* gK_scales_base = (fp8_e8m0*)(k_ptr + block_index*k_block_stride + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*sizeof(bf16)) + 
rel_idx_in_block*NUM_SCALES*sizeof(fp8_e8m0)); fp8_e8m0 scales_e8m0[NUM_SCALES]; *(int64_t*)scales_e8m0 = __ldg((int64_t*)gK_scales_base); CUTE_UNROLL for (int i = 0; i < NUM_SCALES; i += 2) { *(__nv_bfloat162_raw*)(scales+i) = __nv_cvt_e8m0x2_to_bf162raw(*(__nv_fp8x2_storage_t*)(scales_e8m0+i)); } } // Wait for the nope buffer to be available if (round == 0) { plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1); } if (CLUSTER_SIZE == 2 && round == 0 && idx_in_warpgroup == 0) { plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16)); } // Collectively copy from global memory and dequant // For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py fp8* gK_nope = gK_base + (lane_idx/8)*16; if (token_index == -1) { CUTE_UNROLL for (int i = 0; i < NUM_SCALES; ++i) scales[i] = (bf16)0.0f; } CUTE_UNROLL for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) { fp8x16 cur_fp8x16 = load_128b_from_gmem(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B (for V3.2) and the performance is the best among all cache hints (for MODEL1) bf16 scale = scales[MODEL_TYPE == ModelType::V32 ? 
dim_idx/2 : dim_idx]; auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) { int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE; bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, __bfloat162bfloat162(*(__nv_bfloat16*)(&scale))); *(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; if constexpr (CLUSTER_SIZE == 2) { st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); } }; if (token_index == -1) *(uint128_t*)(&cur_fp8x16) = uint128_t(); dequant_and_save_bf16x8(cur_fp8x16.lo, 0); dequant_and_save_bf16x8(cur_fp8x16.hi, 8); } bf16* gK_rope; if constexpr (MODEL_TYPE == ModelType::V32) { gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8; } else { gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE) + (lane_idx/8)*8; } bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE; bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base); CUTE_UNROLL for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) { bf16x8 cur_bf16x8 = load_128b_from_gmem(gK_rope + dim_idx*32); if constexpr (MODEL_TYPE == ModelType::V32) { // NOTE We do not need to mask the RoPE part for V3.2 since it isn't involved in the SV gemm } else { if (token_index == -1) *(uint128_t*)(&cur_bf16x8) = uint128_t(); } int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE; *(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; if constexpr (CLUSTER_SIZE == 2) { st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); } } } fence_view_async_shared(); if (idx_in_warpgroup < 32) { // We put this after fence_view_async_shared() since this won't be read by async proxy auto is_index_valid = [&](int index, int offset_within_thread) -> bool { if constexpr (MODEL_TYPE == ModelType::V32) { return index != -1; } else { return index != -1 && rel_block_idx*TOPK_BLOCK_SIZE + lane_idx*2 + offset_within_thread 
< topk_length; } }; int2 indices = __ldg((int2*)(indices_base + lane_idx*2)); *(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = { is_index_valid(indices.x, 0), is_index_valid(indices.y, 1) }; } // Signal the barrier plan.bar_k_local_ready[buf_idx].arrive(); bar_phase_k ^= 1 << buf_idx; }; if constexpr (MODEL_TYPE == ModelType::V32) { CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{}); } } else { CUTE_NO_UNROLL for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) { process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{}); } if (args.num_orig_kv_blocks < args.end_block_idx) { process_one_block(max(args.start_block_idx, args.num_orig_kv_blocks), IsExtraBlock{}, IsFirstExtraBlock{}); } CUTE_NO_UNROLL for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks)+1; block_idx < args.end_block_idx; ++block_idx) { process_one_block(block_idx, IsExtraBlock{}, IsNotFirstExtraBlock{}); } } sync_all_threads_in_cluster(); } } #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); } #endif } template __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, Kernel::CLUSTER_SIZE) flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TMAParams tma_params) { Kernel::devfunc(params, tma_params); } template void KernelTemplate::run(const SparseAttnDecodeParams ¶ms) { KU_ASSERT(params.h_kv == 1); KU_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0); KU_ASSERT(params.d_qk == HEAD_DIM_K); KU_ASSERT(params.d_v == HEAD_DIM_V); KU_ASSERT(params.h_q % BLOCK_M == 0); if constexpr (MODEL_TYPE == ModelType::MODEL1) { constexpr int BYTES_PER_TOKEN = HEAD_DIM_NOPE + 2*HEAD_DIM_ROPE + 8; KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 
sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous if (params.extra_kv != nullptr) { KU_ASSERT(params.stride_extra_kv_row == BYTES_PER_TOKEN, "Each page block in extra KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous } } else { KU_ASSERT(params.extra_kv == nullptr, "V3.2 does not support extra KV cache"); KU_ASSERT(params.topk_length == nullptr, "V3.2 does not support dynamic topk length"); KU_ASSERT(params.stride_kv_row == 656); // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16) } auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q, params.b); auto tma_Q = cute::make_tma_copy( SM90_TMA_LOAD{}, make_tensor( make_gmem_ptr((bf16*)params.q), make_layout( shape_Q, make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b) ) ), SmemLayoutQ{} ); CUtensorMap tensor_map_o; { // Here we manually construct TMA descriptor to store O, in order to leverage 5D TMA uint64_t size[5] = {OBUF_SW, (unsigned long)params.h_q, HEAD_DIM_V/OBUF_SW, (unsigned long)params.s_q, (unsigned long)params.b}; uint64_t stride[4] = {params.stride_o_h_q*sizeof(bf16), OBUF_SW*sizeof(bf16), params.stride_o_s_q*sizeof(bf16), params.stride_o_b*sizeof(bf16)}; uint32_t box_size[5] = {OBUF_SW, BLOCK_M, HEAD_DIM_V/OBUF_SW, 1, 1}; uint32_t elem_stride[5] = {1, 1, 1, 1, 1}; CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( &tensor_map_o, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, params.out, size, stride, box_size, elem_stride, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, OBUF_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B : OBUF_SW == 32 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B : OBUF_SW == 16 ? 
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE ); KU_ASSERT(res == CUresult::CUDA_SUCCESS); } TmaParams< decltype(shape_Q), decltype(tma_Q) > tma_params = { shape_Q, tma_Q, tensor_map_o }; auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel, decltype(tma_params)>; constexpr size_t smem_size = sizeof(SharedMemoryPlan); KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // NOTE Don't use PDL because of potential compiler bugs! // cudaLaunchAttribute mla_kernel_attributes[1]; // mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; // mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; // cudaLaunchConfig_t mla_kernel_config = { // dim3(num_m_block, params.h_k, params.num_sm_parts), // dim3(NUM_THREADS, 1, 1), // smem_size, // stream, // mla_kernel_attributes, // 1 // }; // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); cutlass::ClusterLaunchParams launch_params = { dim3(NUM_M_BLOCKS, params.s_q, params.num_sm_parts), dim3(NUM_THREADS, 1, 1), dim3(CLUSTER_SIZE, 1, 1), smem_size, params.stream }; cutlass::launch_kernel_on_cluster( launch_params, (void*)mla_kernel, params, tma_params ); KU_CHECK_KERNEL_LAUNCH(); } template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms) { KernelTemplate::run(params); } } ================================================ FILE: csrc/sm90/decode/sparse_fp8/splitkv_mla.h ================================================ #pragma once #include "params.h" namespace sm90::decode::sparse_fp8 { template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); } ================================================ FILE: csrc/sm90/helpers.h ================================================ #pragma once 
#include
#include

namespace sm90 {

// Issues an asynchronous 16-byte copy from global memory `src` to shared memory `dst`
// using `cp.async.cg` (cached in L2 only, not in L1) with an L2 256-byte prefetch hint.
// `dst` is converted to a 32-bit shared-memory address via cute::cast_smem_ptr_to_uint,
// so it must point into shared memory. Only the copy is issued here: no cp.async
// commit/wait is performed, so completion must be arranged by the caller.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
    asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" :: "r"(dst_addr), "l"(src), "n"(16));
}

// Predicated overload of the copy above that also attaches an L2 cache-policy descriptor
// (`cache_policy`, e.g. a value returned by createpolicy_evict_last/first below).
// The source-size operand is 16 when `pred` is true and 0 otherwise. Per PTX cp.async
// semantics, bytes beyond src-size are zero-filled, so pred == false writes 16 zero bytes
// to `dst`.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) {
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
    asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n" :: "r"(dst_addr), "l"(src), "r"(pred?16:0), "l"(cache_policy));
}

// Returns a 64-bit L2 cache-policy descriptor (createpolicy.fractional, fraction 1.0)
// that marks the covered accesses as evict_last.
__forceinline__ __device__ int64_t createpolicy_evict_last() {
    int64_t res;
    asm volatile(
        "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
        : "=l"(res)
        :
    );
    return res;
}

// Same as createpolicy_evict_last(), but the descriptor marks accesses as evict_first.
__forceinline__ __device__ int64_t createpolicy_evict_first() {
    int64_t res;
    asm volatile(
        "createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t"
        : "=l"(res)
        :
    );
    return res;
}

__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
    // In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in
    // two particular rows. This function converts the local_row_idx (0 or 1) to the actual row_idx:
    // warp w of the warpgroup covers rows [16w, 16w+16), and within that range lane l owns
    // row l/4 (local_row_idx == 0) and row l/4 + 8 (local_row_idx == 1).
    // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
    return row_idx;
}

// Converts the index of an element inside a thread's A/C WGMMA fragment to its column:
// every group of 4 elements advances by 8 columns, each lane of a quad (idx_in_warpgroup % 4)
// owns a pair of adjacent columns, and bit 0 of local_elem_idx picks the column within the pair.
// Bit 1 of local_elem_idx does not affect the column (it selects between the two rows returned
// by get_AorC_row_idx).
__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
    int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);
    return col_idx;
}

// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao.
// Warpgroup-wide GEMM: tCrC (+)= tCrA * tCrB, issued as one cute::gemm per k-block (mode 2 of the
// fragments) so that the accumulate mode can be switched after the first k-block.
// Template switches used by the body:
//   arrive    - issue warpgroup_arrive() before the MMAs.
//   zero_init - set GMMA::ScaleOut::Zero for the first k-block, i.e. overwrite tCrC instead of
//               accumulating into it.
//   commit    - call warpgroup_commit_batch() after the MMAs.
//   wg_wait   - when >= 0, call warpgroup_wait() before returning.
// Is_RS is presumably true when A is a register fragment (the "RS" WGMMA form) rather than a
// shared-memory descriptor; only then is tCrA fenced with warpgroup_fence_operand.
// tiled_mma is taken by reference: when zero_init is false, the first k-block uses whatever
// accumulate_ mode tiled_mma currently holds, and on return accumulate_ is always ScaleOut::One
// (if at least one k-block was issued).
template __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    using namespace cute;
    constexpr bool Is_RS = !cute::is_base_of::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // Single-call form kept for reference; the manual loop below is used instead:
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); }
}

// A simpler version of gemm
// SS form: both operands are read from shared memory. sA/sB are partitioned for this thread's
// slice (idx_in_warpgroup) of `tiled_mma`, then one cute::gemm is issued per k-block.
// clear_accum == true makes the first k-block overwrite rC_frag (ScaleOut::Zero); all later
// k-blocks accumulate. Only warpgroup_arrive() is issued: there is no warpgroup_commit_batch()
// or warpgroup_wait() in this helper, so the caller has to commit and wait before rC_frag is read.
// tiled_mma is taken by value, so the caller's accumulate_ mode is left untouched.
template __forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
    using namespace cute;
    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
    Tensor sA_frag = thr_mma.partition_fragment_A(sA);
    Tensor sB_frag = thr_mma.partition_fragment_B(sB);
    // Both operands must be split into the same number of k-blocks.
    static_assert(size<2>(sA_frag) == size<2>(sB_frag));

    warpgroup_fence_operand(rC_frag);
    warpgroup_arrive();
    tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < size<2>(sA_frag); ++k) {
        cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_fence_operand(rC_frag);
}

// RS form of gemm_ss: A is an already-partitioned register fragment (rA_frag), B is partitioned
// from shared memory for this thread's slice. rA_frag is fenced before and after the MMAs.
// Same contract as gemm_ss: clear_accum only affects the first k-block, and no commit/wait is
// performed here, so the caller has to commit and wait before rC_frag is read.
template __forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
    using namespace cute;
    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
    Tensor sB_frag = thr_mma.partition_fragment_B(sB);
    // Both operands must be split into the same number of k-blocks.
    static_assert(size<2>(rA_frag) == size<2>(sB_frag));

    warpgroup_fence_operand(const_cast(rA_frag));
    warpgroup_fence_operand(rC_frag);
    warpgroup_arrive();
    tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < size<2>(rA_frag); ++k) {
        cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_fence_operand(rC_frag);
    warpgroup_fence_operand(const_cast(rA_frag));
}

// Returns the value of the %smid special register for the calling thread.
// NOTE(review): the asm is not volatile, so the compiler may merge repeated calls; PTX documents
// that %smid may change during execution, so treat the result as a hint.
__forceinline__ __device__ uint32_t get_sm_id() {
    uint32_t ret;
    asm("mov.u32 %0, %%smid;" : "=r"(ret));
    return ret;
}

// 16777216 == 1 << 24.
// peer_addr = my_addr ^ PEER_ADDR_MASK. Not sure if this number is the same on all GPUs.
// NOTE(review): this relies on an empirically observed address layout; the PTX `mapa`
// instruction is the documented way to map a shared-memory address to another CTA of the cluster.
static constexpr int PEER_ADDR_MASK = 16777216;
// Returns the peer-CTA counterpart of pointer `p` by XOR-ing its raw value with PEER_ADDR_MASK.
// Correctness depends on the address-layout assumption noted next to PEER_ADDR_MASK; the
// resulting pointer is not validated.
template CUTE_DEVICE T* get_peer_addr(const T* p) {
    return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}

// Issues a TMA copy of `src` into `dst` through `tma_copy`, reporting completion on the
// transaction barrier `bar` with the given L2 cache hint. Uses the TMA slice for index 0
// (cute::_0{}); callers in this file invoke it from a single elected thread and arm `bar`
// themselves (e.g. arrive_and_expect_tx with the byte count).
// NOTE(review): the `0` passed to with() is presumably the (unused) multicast mask — confirm
// against the CuTe SM90_TMA_LOAD `with()` overload.
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    cutlass::arch::ClusterTransactionBarrier &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    auto thr_tma = tma_copy.get_slice(cute::_0{});
    cute::copy(
        tma_copy.with(reinterpret_cast(bar), 0, cache_hint),
        thr_tma.partition_S(src),
        thr_tma.partition_D(dst)
    );
}

}

================================================
FILE: csrc/sm90/prefill/sparse/config.h
================================================
#pragma once

#include
#include
#include
#include
#include
#include
#include
#include "defines.h"
#include "params.h"

namespace sm90::fwd {

using namespace cute;

// Compile-time configuration (tile sizes, shared-memory layouts, MMA types, barrier IDs) of the
// SM90 sparse-attention prefill forward kernel. D_QK is the Q/K head dimension; fwd.cu
// dispatches it as 512 or 576.
template class KernelTemplate {
public:
    static constexpr int D_Q = D_QK;
    static constexpr int D_K = D_QK;
    static constexpr int D_V = 512;         // Head dimension of V / O (fixed)
    static constexpr int B_H = 64;          // Query heads handled per CTA (phase1 uses h_q/B_H head blocks)
    static constexpr int B_TOPK = 64;   // TopK block size
    static constexpr int NUM_THREADS = 128*3;   // Three 128-thread warpgroups
    static constexpr float MAX_INIT_VAL = -1e30;    // We use this number as the initial value for mi (max logits)

    // 128B-swizzled shared-memory layouts built from 64-column tiles; NUM_TILES is the number of
    // 64-wide tiles along the head dimension.
    template using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
        GMMA::Layout_K_SW128_Atom{},
        Shape, Int<64*NUM_TILES>>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    template using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
        GMMA::Layout_K_SW128_Atom{},
        Shape, Int<64*NUM_TILES>>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    template using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
        GMMA::Layout_SW128_Atom{},
        Shape, Int<64*NUM_TILES>>{},
        Step<_1, _2>{}
    ), Shape<_1, _1>{}));

    // Transposed view composed on top of SmemLayoutKTiles; V is read through this view.
    template using SmemLayoutKTilesTransposed = decltype(composition(
        SmemLayoutKTiles{},
        Layout, Int>, Stride, _1>>{}
    ));

    using SmemLayoutQ = SmemLayoutQTiles;
    using SmemLayoutO = SmemLayoutOTiles;
    using SmemLayoutK = SmemLayoutKTiles;
    using SmemLayoutV = SmemLayoutKTilesTransposed;
    using SmemLayoutHalfV = SmemLayoutKTilesTransposed;

    using SmemLayoutS = decltype(coalesce(tile_to_shape(
        GMMA::Layout_K_SW128_Atom{},
        Shape, Int>{}
    ), Shape<_1, _1>{}));

    // Shared-memory layout of one CTA.
    struct SharedMemoryPlan {
        // Q and O share the same storage.
        union {
            array_aligned> q;
            array_aligned> o;
        } q_o;
        array_aligned> k[2];    // One K buffer per consumer warpgroup (phase1 uses k[0] for WG0, k[1] for WG1)
        array_aligned> s[D_QK == 576 ? 1 : 2];  // For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers
        bool is_kv_valid[2][B_TOPK];    // Per-K-buffer validity of each token; invalid tokens have their logits set to -inf in phase1
        float2 sM[32];
        float2 sL[64];  // For reduction across WG0/1 in epilogue
        float final_max_logits[64], final_lse[64];
        transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
    };

    // Q @ K^T: 64x64x16 WGMMA, both operands from shared memory (SS).
    using TiledMMA_QK = decltype(make_tiled_mma(
        GMMA::MMA_64x64x16_F32BF16BF16_SS{},
        Layout>{}
    ));

    // P @ V with P taken from registers (RS).
    using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
        GMMA::MMA_64x256x16_F32BF16BF16_RS{},
        Layout>{}
    ));

    // P @ V with P taken from shared memory (SS).
    using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
        GMMA::MMA_64x256x16_F32BF16BF16_SS{},
        Layout>{}
    ));

    // Kernel-side TMA state: the Q tensor shape, the TMA copy used to load Q, and a raw
    // tensor map for O.
    template<
        typename Shape_Q, typename TMA_Q
    >
    struct TmaParams {
        Shape_Q shape_Q;
        TMA_Q tma_Q;
        CUtensorMap tensor_map_O;
    };

    // IDs of the named barriers used by the kernel.
    enum NamedBarriers : uint32_t {
        wg0_bunch_0_ready = 0,
        wg1_bunch_0_ready = 1,
        wg0_s0_ready = 2,
        wg1_s1_ready = 3,
        sL_ready = 4,
        warpgroup0_sync = 5,
        warpgroup1_sync = 6,
        epilogue_sync = 7
    };

    // Save rPb (64x64, bfloat16) to sP using the stmatrix instruction.
    // The register fragment is retiled according to TiledMMA_QK's C layout, so rPb must have been
    // produced with that MMA's accumulator partitioning for thread idx_in_warpgroup.
    template<
        typename Tensor0,
        typename Tensor1
    >
    static __forceinline__ __device__ void save_rS_to_sS(
        Tensor0 const &rPb,
        Tensor1 const &sP,
        int idx_in_warpgroup
    ) {
        auto r2s_copy = make_tiled_copy_C(
            Copy_Atom{},
            TiledMMA_QK{}
        );
        ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
        Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
        Tensor thr_copy_sP = thr_copy.partition_D(sP);
        cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
    }

    // Device-side body of the kernel (defined in phase1.cuh).
    template
    static __device__ __forceinline__ void devfunc(const SparseAttnFwdParams ¶ms, const TMAParams &tma_params);

    // Host-side launcher (definition not shown here).
    static void run(const SparseAttnFwdParams ¶ms);
};

};
================================================
FILE: csrc/sm90/prefill/sparse/fwd.cu
================================================
#include "fwd.h"

#include <stdexcept>

#include "phase1.h"

namespace sm90 {

// Host-side entry point of the SM90 sparse-attention prefill forward pass.
//
// run_fwd_phase1_kernel only exists for a fixed set of compile-time configurations,
// so this function selects the instantiation that matches the runtime parameters:
//   - params.d_qk (Q/K head dimension) must be 512 or 576;
//   - a non-null params.topk_length selects the HAVE_TOPK_LENGTH = true variant.
// Any other d_qk is rejected with std::runtime_error.
void run_fwd_kernel(const SparseAttnFwdParams& params) {
    const bool use_topk_length = (params.topk_length != nullptr);

    switch (params.d_qk) {
    case 512:
        if (use_topk_length) {
            sm90::fwd::run_fwd_phase1_kernel<512, true>(params);
        } else {
            sm90::fwd::run_fwd_phase1_kernel<512, false>(params);
        }
        break;
    case 576:
        if (use_topk_length) {
            sm90::fwd::run_fwd_phase1_kernel<576, true>(params);
        } else {
            sm90::fwd::run_fwd_phase1_kernel<576, false>(params);
        }
        break;
    default:
        throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
    }
}

}  // namespace sm90

================================================
FILE: csrc/sm90/prefill/sparse/fwd.h
================================================
#pragma once

#include "params.h"

namespace sm90 {

// Runs the SM90 sparse-attention prefill forward kernel described by `params`.
// Throws std::runtime_error when params.d_qk is not a supported head dimension.
void run_fwd_kernel(const SparseAttnFwdParams& params);

}

================================================
FILE: csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu
================================================
#include "../phase1.h"
#include "../phase1.cuh"

namespace sm90::fwd {

// Explicit instantiation for D_QK = 512, HAVE_TOPK_LENGTH = false.
// NOTE: every (D_QK, HAVE_TOPK_LENGTH) combination of run_fwd_phase1_kernel is instantiated in
// its own .cu file (phase1_k512, phase1_k512_topklen, phase1_k576, phase1_k576_topklen) so that
// the four instantiations compile in parallel.
template void run_fwd_phase1_kernel<512, false>(const SparseAttnFwdParams& params);

}

================================================
FILE: csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu
================================================
#include "../phase1.h"
#include "../phase1.cuh"

namespace sm90::fwd {

// Explicit instantiation for D_QK = 512, HAVE_TOPK_LENGTH = true.
// NOTE: every (D_QK, HAVE_TOPK_LENGTH) combination of run_fwd_phase1_kernel is instantiated in
// its own .cu file so that the instantiations compile in parallel.
template void run_fwd_phase1_kernel<512, true>(const SparseAttnFwdParams& params); } ================================================ FILE: csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu ================================================ #include "../phase1.h" #include "../phase1.cuh" namespace sm90::fwd { template void run_fwd_phase1_kernel<576, false>(const SparseAttnFwdParams& params); } ================================================ FILE: csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu ================================================ #include "../phase1.h" #include "../phase1.cuh" namespace sm90::fwd { template void run_fwd_phase1_kernel<576, true>(const SparseAttnFwdParams& params); } ================================================ FILE: csrc/sm90/prefill/sparse/phase1.cuh ================================================ #pragma once #include "config.h" #include "utils.h" #include "../../helpers.h" namespace sm90::fwd { using namespace cute; CUTE_DEVICE void st_global_cs_128(float f0, float f1, float f2, float f3, void *dst_ptr) { asm volatile("st.weak.global.cs.v4.f32 [%0], {%1, %2, %3, %4};\n" : : "l"(dst_ptr), "f"(f0), "f"(f1), "f"(f2), "f"(f3) ); } CUTE_DEVICE float2 __shfl_xor_sync_float2( uint32_t mask, float2 value, int offset ) { float2 res; *reinterpret_cast(&res) = __shfl_xor_sync( mask, *reinterpret_cast(&value), offset ); return res; } CUTE_DEVICE void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) { uint32_t smem_int_ptr = cast_smem_ptr_to_uint(src_ptr); asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" : : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes) : "memory"); } template template __device__ void KernelTemplate::devfunc(const SparseAttnFwdParams ¶ms, const TMAParams &tma_params) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) const int q_h_idx = blockIdx.x % (params.h_q/B_H); const 
int s_q_idx = blockIdx.x / (params.h_q/B_H); const int warpgroup_idx = cutlass::canonical_warp_group_idx(); const int warp_idx = cutlass::canonical_warp_idx_sync(); const int idx_in_warpgroup = threadIdx.x % 128; // Define shared tensors extern __shared__ char wksp_buf[]; SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{}); Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{}); Tensor sS0 = make_tensor(make_smem_ptr(D_QK == 576 ? plan.k[0].data()+64*512 : plan.s[1].data()), SmemLayoutS{}); // Overlap with sK0's RoPE part for V3.2 Tensor sS1 = make_tensor(make_smem_ptr(plan.s[0].data()), SmemLayoutS{}); if (warp_idx == 0 && elect_one_sync()) { // Prefetch TMA descriptors cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); cute::prefetch_tma_descriptor(&tma_params.tensor_map_O); // Initialize barriers plan.bar_q.init(1); CUTE_UNROLL for (int i = 0; i < 2; ++i) { plan.bar_k0_free[i].init(128); plan.bar_k0_ready[i].init(128); plan.bar_k1_free[i].init(128); plan.bar_k1_ready[i].init(128); } plan.bar_is_kv_valid_ready.init(16); fence_barrier_init(); } __syncthreads(); const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk; const int num_topk_blocks = HAVE_TOPK_LENGTH ? 
ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK); if (warpgroup_idx == 0 || warpgroup_idx == 1) { cutlass::arch::warpgroup_reg_alloc<216>(); if (warp_idx == 0 && elect_one_sync()) { // Load Q Tensor gQ = flat_divide( tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), Tile, Int>{} )(_, _, q_h_idx, _0{}); launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16)); } float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation float rL[2] = {0.0f, 0.0f}; Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); cute::fill(rO, 0.0f); // Wait for Q plan.bar_q.wait(0); bool cur_bar_wait_phase = 0; struct Warpgroup0 {}; struct Warpgroup1 {}; auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) { constexpr bool IS_WG1 = std::is_same_v; TiledMMA tiled_mma_QK = TiledMMA_QK{}; Tensor sQ_tile = flat_divide(sQ, Tile, Int<64>>{})(_, _, _0{}, tile_idx); Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{}); gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup); }; auto mask_rP = [&](auto warpgroup_idx) { constexpr bool IS_WG1 = std::is_same_v; plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); CUTE_UNROLL for (int row_idx = 0; row_idx < 2; ++row_idx) { CUTE_UNROLL for (int i = row_idx*2; i < size(rP); i += 4) { int col = 8*(i/4) + (idx_in_warpgroup%4)*2; if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY; if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY; } } }; auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) { plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); constexpr bool IS_WG1 = std::is_same_v; const 
float scale = params.sm_scale_div_log2; float r_sM[2]; if constexpr (IS_WG1) { *(float2*)r_sM = plan.sM[idx_in_warpgroup/4]; } float new_maxs[2]; CUTE_UNROLL for (int row_idx = 0; row_idx < 2; ++row_idx) { // Get rowwise max float cur_max = -INFINITY; CUTE_UNROLL for (int i = row_idx*2; i < size(rP); i += 4) { cur_max = max(cur_max, max(rP(i), rP(i+1))); } cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); cur_max *= scale; // Get new max and scale // For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round) new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max); // Scale O float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]); CUTE_UNROLL for (int i = row_idx*2; i < size(rO); i += 4) { rO(i) *= scale_for_o; rO(i+1) *= scale_for_o; } // Get rS float cur_sum = 0; CUTE_UNROLL for (int i = row_idx*2; i < size(rP); i += 4) { rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]); rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]); rS(i) = (bf16)rP(i); rS(i+1) = (bf16)rP(i+1); cur_sum += rP(i) + rP(i+1); } rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum; } __syncwarp(); if (idx_in_warpgroup%4 == 0) { plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs; } rM[0] = new_maxs[0]; rM[1] = new_maxs[1]; }; auto reduce_L = [&]() { // Reduce L // For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131 rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); if (idx_in_warpgroup%4 == 0) plan.sL[threadIdx.x/4] = *(float2*)(rL); NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready); float2 peer_L = plan.sL[(threadIdx.x/4)^32]; rL[0] += peer_L.x; rL[1] += peer_L.y; }; auto store_O = [&]() { float scale_factors[2]; CUTE_UNROLL for (int i = 0; i < 2; ++i) { 
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : params.attn_sink[q_h_idx*B_H + get_AorC_row_idx(i, idx_in_warpgroup)]*CUDART_L2E_F; scale_factors[i] = 1.0f / (rL[i] + exp2f(attn_sink - rM[i])); if (rL[i] == 0.0f) scale_factors[i] = 0.0f; // The output should be 0 whatever attn_sink is } Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{}); bf16* stsm_addrs[4]; int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16); CUTE_UNROLL for (int i = 0; i < 64/16; ++i) { stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i); } bool s2g_pred = warp_idx%4 == 0 && elect_one_sync(); warpgroup_wait<0>(); CUTE_UNROLL for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) { // Convert constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size bf16 cur_rOb[NUM_ELEMS_EACH_TILE]; CUTE_UNROLL for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) { cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]); } // R -> S CUTE_UNROLL for (int i = 0; i < 64/16; ++i) { SM90_U32x4_STSM_N::copy( *reinterpret_cast(cur_rOb + i*8 + 0), *reinterpret_cast(cur_rOb + i*8 + 2), *reinterpret_cast(cur_rOb + i*8 + 4), *reinterpret_cast(cur_rOb + i*8 + 6), *reinterpret_cast(stsm_addrs[i] + tile_idx*(B_H*64)) ); } fence_view_async_shared(); NamedBarrier::arrive_and_wait(128, warpgroup_idx ? 
NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync); // S -> G if (s2g_pred) { int g_tile_idx = warpgroup_idx*4 + tile_idx; SM90_TMA_STORE_3D::copy( &tma_params.tensor_map_O, plan.q_o.o.data() + g_tile_idx*(B_H*64), g_tile_idx*64, q_h_idx*B_H, s_q_idx ); } } cute::tma_store_arrive(); }; if (warpgroup_idx == 0) { // Warpgroup 0 auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) { plan.bar_k0_ready[0].wait(cur_bar_wait_phase); qkt_gemm_one_tile(Warpgroup0{}, 0, true); qkt_gemm_one_tile(Warpgroup0{}, 1, false); qkt_gemm_one_tile(Warpgroup0{}, 2, false); qkt_gemm_one_tile(Warpgroup0{}, 3, false); warpgroup_commit_batch(); }; auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) { plan.bar_k0_ready[1].wait(cur_bar_wait_phase); qkt_gemm_one_tile(Warpgroup0{}, 4, false); qkt_gemm_one_tile(Warpgroup0{}, 5, false); qkt_gemm_one_tile(Warpgroup0{}, 6, false); qkt_gemm_one_tile(Warpgroup0{}, 7, false); if constexpr (D_QK == 576) { qkt_gemm_one_tile(Warpgroup0{}, 8, false); } warpgroup_commit_batch(); }; auto scale_rS = [&](float scales[2]) { CUTE_UNROLL for (int row = 0; row < 2; ++row) { CUTE_UNROLL for (int i = row*2; i < size(rP); i += 4) { rS(i) = (bf16)(rP(i) * scales[row]); rS(i+1) = (bf16)(rP(i+1) * scales[row]); } } }; auto rescale_rO = [&](float scales[2]) { CUTE_UNROLL for (int row = 0; row < 2; ++row) { CUTE_UNROLL for (int i = row*2; i < size(rO); i += 4) { rO(i) *= scales[row]; rO(i+1) *= scales[row]; } rL[row] *= scales[row]; } }; CUTE_NO_UNROLL for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{}); Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{}); if (block_idx == 0) { // NOTE: We put this code here to avoid register spilling pipelined_wait_and_qkt_gemm_l(); pipelined_wait_and_qkt_gemm_r(); warpgroup_wait<0>(); } // Online softmax, inform WG1 
mask_rP(Warpgroup0{}); online_softmax_and_rescale_o(Warpgroup0{}); NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready); // Issue rO0 += rS0 @ sV0l gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup); warpgroup_commit_batch(); // Mark V0L as free warpgroup_wait<0>(); plan.bar_k0_free[0].arrive(); // Wait for new sM, scale rS, save, inform WG1 NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready); float new_rM[2], scale_factors[2]; *(float2*)new_rM = plan.sM[idx_in_warpgroup/4]; CUTE_UNROLL for (int i = 0; i < 2; ++i) { scale_factors[i] = exp2f(rM[i] - new_rM[i]); rM[i] = new_rM[i]; } scale_rS(scale_factors); save_rS_to_sS(rS, sS0, idx_in_warpgroup); fence_view_async_shared(); NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready); // Wait for sS1 NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready); // Rescale rO0, Issue rO0 += sS1 @ sV1L rescale_rO(scale_factors); gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup); warpgroup_commit_batch(); cur_bar_wait_phase ^= 1; if (block_idx+2 < num_topk_blocks) { // Launch the next QK^T GEMM pipelined_wait_and_qkt_gemm_l(); // Mark V1L as free warpgroup_wait<1>(); plan.bar_k1_free[0].arrive(); pipelined_wait_and_qkt_gemm_r(); // Wait for rP0 = sQ @ sK0 warpgroup_wait<0>(); } else { // Mark V1L as free warpgroup_wait<0>(); plan.bar_k1_free[0].arrive(); } } reduce_L(); store_O(); } else { // Warpgroup 1 auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) { plan.bar_k1_ready[1].wait(cur_bar_wait_phase); qkt_gemm_one_tile(Warpgroup1{}, 4, true); qkt_gemm_one_tile(Warpgroup1{}, 5, false); qkt_gemm_one_tile(Warpgroup1{}, 6, false); qkt_gemm_one_tile(Warpgroup1{}, 7, false); if constexpr (D_QK == 576) { qkt_gemm_one_tile(Warpgroup1{}, 8, false); } plan.bar_k1_ready[0].wait(cur_bar_wait_phase); qkt_gemm_one_tile(Warpgroup1{}, 0, false); qkt_gemm_one_tile(Warpgroup1{}, 1, false); qkt_gemm_one_tile(Warpgroup1{}, 2, false); 
qkt_gemm_one_tile(Warpgroup1{}, 3, false); warpgroup_commit_batch(); }; CUTE_NO_UNROLL for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{}); Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{}); // Issue rP1 = sQ @ sK1, and wait pipelined_wait_and_qkt_gemm(); warpgroup_wait<0>(); mask_rP(Warpgroup1{}); // Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready) NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready); online_softmax_and_rescale_o(Warpgroup1{}); NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready); // Issue rO1 += rS1 @ sV1R gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup); warpgroup_commit_batch(); // Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready); gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup); warpgroup_commit_batch(); // Save rS1, inform WG0 fence_view_async_shared(); NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready); // Wait for GEMM, and inform that sV1R is free warpgroup_wait<1>(); plan.bar_k1_free[1].arrive(); // Wait for GEMM, and inform that sV0R is free warpgroup_wait<0>(); plan.bar_k0_free[1].arrive(); cur_bar_wait_phase ^= 1; } reduce_L(); store_O(); // Save lse if (idx_in_warpgroup%4 == 0) { for (int row = 0; row < 2; ++row) { int real_row = get_AorC_row_idx(row, idx_in_warpgroup); bool is_no_valid_tokens = rL[row] == 0.0f; plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]*CUDART_LN2_F; plan.final_lse[real_row] = is_no_valid_tokens ? 
+INFINITY : logf(rL[row]) + rM[row]*CUDART_LN2_F; } fence_view_async_shared(); } NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync); if (idx_in_warpgroup == 0) { int g_offset = s_q_idx*params.h_q + q_h_idx*B_H; SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float)); SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float)); cute::tma_store_arrive(); } } } else { // Producer warpgroup cutlass::arch::warpgroup_reg_dealloc<72>(); constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE; constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; int idx_in_group = idx_in_warpgroup % GROUP_SIZE; int group_idx = idx_in_warpgroup / GROUP_SIZE; int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk] bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8)); bf16* my_gKV_base = params.kv + idx_in_group*8; int64_t token_indices[2][NUM_ROWS_PER_GROUP]; bool is_token_valid[2][NUM_ROWS_PER_GROUP]; auto load_token_indices = [&](int block_idx) { CUTE_UNROLL for (int buf_idx = 0; buf_idx < 2; ++buf_idx) { CUTE_UNROLL for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx; int t = __ldg(gIndices + offs); token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster bool is_cur_token_valid = t >= 0 && t < params.s_kv; if constexpr (HAVE_TOPK_LENGTH) { is_cur_token_valid &= offs < topk_length; } is_token_valid[buf_idx][local_row] = is_cur_token_valid; } } }; int64_t cache_policy = createpolicy_evict_last(); auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) { // Copy some K/V tiles from global memory to shared memory // A tile has a shape of 64 (B_TOPK) x 64 // `buf_idx` is the index of the shared memory buffer, 0 or 1 // `tile_idx` is 
the index of the tile to load, from 0 to D_K/64-1 = 8 CUTE_UNROLL for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { int64_t token_index = token_indices[buf_idx][local_row]; CUTE_UNROLL for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) { cp_async_cacheglobal_l2_prefetch_256B( my_gKV_base + token_index + tile_idx*64, my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64), is_token_valid[buf_idx][local_row], cache_policy ); } } }; auto commit_to_mbar = [&](transac_bar_t &bar) { cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar)); }; int cur_bar_wait_phase = 1; CUTE_NO_UNROLL for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { load_token_indices(block_idx); // V0L plan.bar_k0_free[0].wait(cur_bar_wait_phase); copy_tiles(block_idx+0, 0, 0, 4); commit_to_mbar(plan.bar_k0_ready[0]); // V1R plan.bar_k1_free[1].wait(cur_bar_wait_phase); copy_tiles(block_idx+1, 1, 4, D_K/64); commit_to_mbar(plan.bar_k1_ready[1]); // V0R plan.bar_k0_free[1].wait(cur_bar_wait_phase); copy_tiles(block_idx+0, 0, 4, D_K/64); commit_to_mbar(plan.bar_k0_ready[1]); // V1L plan.bar_k1_free[0].wait(cur_bar_wait_phase); copy_tiles(block_idx+1, 1, 0, 4); commit_to_mbar(plan.bar_k1_ready[0]); // Valid mask // NOTE: V1R's finish implies maskings of the last round have finished if (idx_in_group == 0) { CUTE_UNROLL for (int buf_idx = 0; buf_idx < 2; ++buf_idx) CUTE_UNROLL for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; plan.bar_is_kv_valid_ready.arrive(); } cur_bar_wait_phase ^= 1; } } #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); } #endif } template __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1) sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TMAParams tma_params) { Kernel::devfunc(params, 
tma_params); } template void KernelTemplate::run(const SparseAttnFwdParams ¶ms) { KU_ASSERT(params.h_kv == 1); KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings KU_ASSERT(params.topk > 0); KU_ASSERT(params.h_q % B_H == 0); auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); auto tma_Q = cute::make_tma_copy( SM90_TMA_LOAD{}, make_tensor( make_gmem_ptr((bf16*)params.q), make_layout( shape_Q, make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) ) ), SmemLayoutQ{} ); CUtensorMap tensor_map_O; { uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q}; uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)}; uint32_t box_size[3] = {64, B_H, 1}; uint32_t elem_stride[3] = {1, 1, 1}; CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( &tensor_map_O, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 3, params.out, size, stride, box_size, elem_stride, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE ); KU_ASSERT(res == CUresult::CUDA_SUCCESS); } TmaParams< decltype(shape_Q), decltype(tma_Q) > tma_params = { shape_Q, tma_Q, tensor_map_O }; auto kernel = &sparse_attn_fwd_kernel, decltype(tma_params)>; constexpr size_t smem_size = sizeof(SharedMemoryPlan); KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); cutlass::ClusterLaunchParams launch_params = { dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE: We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z) dim3(NUM_THREADS, 1, 1), dim3(1, 1, 1), smem_size, params.stream }; cutlass::launch_kernel_on_cluster( launch_params, (void*)kernel, params, tma_params ); KU_CHECK_KERNEL_LAUNCH(); } template void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { 
KernelTemplate::run(params); } } ================================================ FILE: csrc/sm90/prefill/sparse/phase1.h ================================================ #pragma once #include "../../../params.h" namespace sm90::fwd { template void run_fwd_phase1_kernel(const SparseAttnFwdParams& params); } ================================================ FILE: csrc/smxx/decode/combine/combine.cu ================================================ #include "combine.h" #include #include #include #include #include #include #include "params.h" #include "utils.h" using namespace cute; namespace smxx::decode { template __global__ void __launch_bounds__(NUM_THREADS) flash_fwd_mla_combine_kernel(__grid_constant__ const CombineParams params) { // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M] // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m const int batch_idx = blockIdx.x; const int s_q_idx = blockIdx.y; const int h_block_idx = blockIdx.z; const int warp_idx = threadIdx.x / 32; const int lane_idx = threadIdx.x % 32; int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx); if (warp_idx >= num_valid_heads) { return; } const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); const int my_num_splits = end_split_idx - start_split_idx; if (my_num_splits == 1) { return; } FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); Tensor gLseAccum = make_tensor( make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M), Shape, Int>{}, make_stride(params.stride_lse_accum_split, _1{}) ); Tensor gLse = make_tensor( make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M), 
Shape>{}, Stride<_1>{} ); __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS]; // Wait for the previous kernel (the MLA kernel) to finish cudaGridDependencySynchronize(); // Prefetch static_assert(HEAD_DIM_V % (32*4) == 0); constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4); float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q; float4 datas[ELEMS_PER_THREAD]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) { datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL } // Warp #i gathers LseAccum for seq #i { constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32); float local_lse[NUM_LSE_PER_THREAD]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { const int split_idx = i*32 + lane_idx; local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY; } float max_lse = -INFINITY; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) max_lse = max(max_lse, local_lse[i]); CUTLASS_PRAGMA_UNROLL for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf float sum_lse = 0; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) sum_lse = sum_lse + exp2f(local_lse[i] - max_lse); CUTLASS_PRAGMA_UNROLL for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? 
INFINITY : log2f(sum_lse) + max_lse; if (lane_idx == 0) gLse(warp_idx) = global_lse / (float)M_LOG2E; if (params.attn_sink != nullptr) { int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; float attn_sink = __ldg(params.attn_sink + q_head_idx); if (global_lse != INFINITY) { // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf) // If attn_sink is -inf, this has no effect on global_lse global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse)); } else { // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf) global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F; } } CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { const int split_idx = i*32 + lane_idx; smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse); } } __syncwarp(); // Warp #i accumulates activation for seq #i { float4 result[ELEMS_PER_THREAD]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) result[i] = {0.0f, 0.0f, 0.0f, 0.0f}; #pragma unroll 1 for (int split = 0; split < my_num_splits; ++split) { float lse_scale = smem_buf[warp_idx][split]; // if (lse_scale != 0.f) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) { result[i].x += lse_scale * datas[i].x; result[i].y += lse_scale * datas[i].y; result[i].z += lse_scale * datas[i].z; result[i].w += lse_scale * datas[i].w; if (split != my_num_splits-1) { datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128); } } // } } const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) { float4 data = result[i]; ElementT data_converted[4]; data_converted[0] = (ElementT)(data.x); data_converted[1] = 
(ElementT)(data.y); data_converted[2] = (ElementT)(data.z); data_converted[3] = (ElementT)(data.w); static_assert(sizeof(ElementT) == 2); *(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted; } } } #define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ [&] { \ if (NUM_SPLITS <= 32) { \ constexpr static int NAME = 32; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 64) { \ constexpr static int NAME = 64; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 96) { \ constexpr static int NAME = 96; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 128) { \ constexpr static int NAME = 128; \ return __VA_ARGS__(); \ } else if (NUM_SPLITS <= 160) { \ constexpr static int NAME = 160; \ return __VA_ARGS__(); \ } else { \ FLASH_ASSERT(false); \ } \ }() template void run_flash_mla_combine_kernel(CombineParams ¶ms) { static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA FLASH_ASSERT(params.d_v == HEAD_DIM_V); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { constexpr int BLOCK_SIZE_M = 8; constexpr int NUM_THREADS = BLOCK_SIZE_M*32; constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); auto combine_kernel = &flash_fwd_mla_combine_kernel; CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) cudaLaunchAttribute attribute[1]; attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = 1; cudaLaunchConfig_t combine_kernel_config = { dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)), dim3(NUM_THREADS, 1, 1), 0, params.stream, attribute, 1 }; CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params)); }); CHECK_CUDA_KERNEL_LAUNCH(); } template void run_flash_mla_combine_kernel(CombineParams ¶ms); #ifndef FLASH_MLA_DISABLE_FP16 template void 
run_flash_mla_combine_kernel(CombineParams ¶ms); #endif } ================================================ FILE: csrc/smxx/decode/combine/combine.h ================================================ #pragma once #include "params.h" namespace smxx::decode { template void run_flash_mla_combine_kernel(CombineParams ¶ms); } ================================================ FILE: csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu ================================================ #include "get_decoding_sched_meta.h" #include #include #include #include "utils.h" namespace smxx::decode { __global__ void __launch_bounds__(32, 1, 1) get_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) { int *seqlens_k_ptr = params.seqlens_k_ptr; DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; int *num_splits_ptr = params.num_splits_ptr; int batch_size = params.b; int block_size_n = params.block_size_n; int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; int num_sm_parts = params.num_sm_parts; extern __shared__ int shared_mem[]; int* num_blocks_shared = shared_mem; // [batch_size] int* num_splits_shared = shared_mem + batch_size; // [batch_size+1] int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size] int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size] int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size] int total_num_blocks = 0; for (int i = threadIdx.x; i < batch_size; i += 32) { int cur_s_k; if (params.topk == -1) { // Dense model, cur_s_k = actual s_k cur_s_k = __ldg(seqlens_k_ptr + i); } else { // Sparse model, cur_s_k = topk (+ extra topk) cur_s_k = params.topk_length ? __ldg(params.topk_length + i) : params.topk; if (cur_s_k == 0) cur_s_k = 1; // Ensure the main loop will never be empty if (params.extra_topk) { cur_s_k = ku::ceil(cur_s_k, block_size_n); cur_s_k += params.extra_topk_length ? 
__ldg(params.extra_topk_length + i) : params.extra_topk; } } seqlens_k_shared[i] = cur_s_k; int first_token_idx = 0; int last_token_idx = max(cur_s_k-1, 0); int cur_first_block_idx = first_token_idx / block_size_n; int cur_last_block_idx = last_token_idx / block_size_n; // NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx] // NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel. int num_blocks = cur_last_block_idx - cur_first_block_idx + 1; total_num_blocks += num_blocks + fixed_overhead_num_blocks; num_blocks_shared[i] = num_blocks; first_block_idx_shared[i] = cur_first_block_idx; last_block_idx_shared[i] = cur_last_block_idx; } for (int offset = 16; offset >= 1; offset /= 2) { total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); } __syncwarp(); if (threadIdx.x == 0) { int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; int now_req_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; num_splits_shared[0] = 0; for (int i = 0; i < num_sm_parts; ++i) { DecodingSchedMeta cur_meta; cur_meta.begin_req_idx = now_req_idx; cur_meta.begin_block_idx = now_block + first_block_idx_shared[now_req_idx]; cur_meta.begin_split_idx = now_n_split_idx; cur_meta.is_first_req_splitted = (now_block != 0); int remain_payload = payload; while (now_req_idx < batch_size) { int num_blocks = num_blocks_shared[now_req_idx]; int now_remain_blocks = num_blocks - now_block; if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { cum_num_splits += now_n_split_idx + 1; num_splits_shared[now_req_idx + 1] = cum_num_splits; remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; ++now_req_idx; now_block = 0; now_n_split_idx = 0; } else { if (remain_payload - fixed_overhead_num_blocks > 0) { now_block 
+= remain_payload - fixed_overhead_num_blocks; ++now_n_split_idx; remain_payload = 0; } break; } } cur_meta.end_req_idx = now_block > 0 ? now_req_idx : now_req_idx - 1; cur_meta.end_block_idx = now_block > 0 ? now_block + first_block_idx_shared[now_req_idx] : (seqlens_k_shared[now_req_idx-1] == 0 ? 0 : last_block_idx_shared[now_req_idx-1] + 1); cur_meta.is_last_req_splitted = cur_meta.end_block_idx != last_block_idx_shared[cur_meta.end_req_idx] + 1 && seqlens_k_shared[cur_meta.end_req_idx] != 0; if (cur_meta.begin_req_idx == cur_meta.end_req_idx) { cur_meta.is_first_req_splitted = cur_meta.is_last_req_splitted = cur_meta.is_first_req_splitted || cur_meta.is_last_req_splitted; } tile_scheduler_metadata_ptr[i] = cur_meta; } FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0); } __syncwarp(); for (int i = threadIdx.x; i <= batch_size; i += 32) { num_splits_ptr[i] = num_splits_shared[i]; } } void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams ¶ms) { int smem_size = sizeof(int) * (params.b*5+1); CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params); CHECK_CUDA_KERNEL_LAUNCH(); } } ================================================ FILE: csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h ================================================ #pragma once #include "params.h" namespace smxx::decode { void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams ¶ms); } ================================================ FILE: csrc/utils.h ================================================ #pragma once #include #define CHECK_CUDA(call) \ do { \ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ exit(1); \ } \ } while(0) #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) 
#define FLASH_ASSERT(cond) \ do { \ if (not (cond)) { \ fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ exit(1); \ } \ } while(0) #define FLASH_DEVICE_ASSERT(cond) \ do { \ if (not (cond)) { \ printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ asm("trap;"); \ } \ } while(0) #define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); } template __inline__ __host__ __device__ T ceil_div(const T &a, const T &b) { return (a + b - 1) / b; } #ifndef TRAP_ONLY_DEVICE_ASSERT #define TRAP_ONLY_DEVICE_ASSERT(cond) \ do { \ if (not (cond)) \ asm("trap;"); \ } while (0) #endif #ifndef TRAP_ONLY_DEVICE_ASSERT #define TRAP_ONLY_DEVICE_ASSERT(cond) \ do { \ if (not (cond)) \ asm("trap;"); \ } while (0) #endif struct RingBufferState { uint32_t cur_block_idx = 0u; __device__ __forceinline__ void update() { cur_block_idx += 1; } template __device__ __forceinline__ std::pair get() const { uint32_t stage_idx = cur_block_idx % NUM_STAGES; bool phase = (cur_block_idx / NUM_STAGES) & 1; return {stage_idx, phase}; } __device__ __forceinline__ RingBufferState offset_by(const int offset) const { // Must guarantee no underflow uint32_t new_block_idx = static_cast(static_cast(cur_block_idx) + offset); RingBufferState new_state; new_state.cur_block_idx = new_block_idx; return new_state; } }; ================================================ FILE: docs/20250422-new-kernel-deep-dive.md ================================================ # A Deep-Dive Into the New Flash MLA Kernel In the [previous version](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) of the Flash MLA kernel, we have achieved impressive performance: 3000 GB/s in memory-intensive settings and 580 TFlops in compute-bound settings. Now, we're pushing these numbers even further, reaching up to 660 TFlops. In this blog, we present a deep dive into the new kernel, explaining the optimizations and techniques behind this performance boost. 
We'll first explain why the MLA kernel is compute-bound despite being a decoding-stage attention kernel, then discuss our high-level kernel schedule design, and finally cover the technical details of the new kernel. ## A Theoretical Analysis of the MLA Algorithm GPU kernels can be classified as either compute-bound (limited by floating-point operations per second, FLOPs) or memory-bound (limited by memory bandwidth). To identify the kernel's bottleneck, we calculate the ratio of FLOPs to memory bandwidth (FLOPs/byte) and compare it with the GPU's capacity. Assume the number of q heads is $h_q$, the number of q tokens per request is $s_q$ (should be 1 if MTP / speculative decoding is disabled), the number of kv tokens per request is $s_k\ (s_k \gg h_q s_q)$, and the head dimensions of K and V are $d_k$ and $d_v$ respectively. The number of FLOPs is roughly $2 (h_q s_q \cdot d_k \cdot s_k + h_q s_q \cdot s_k \cdot d_v) = 2 h_q s_q s_k (d_k+d_v)$, and the memory access volume (in bytes) is $\mathop{\text{sizeof}}(\text{bfloat16}) \times (h_q s_q d_k + s_k d_k + h_q s_q d_v) \approx 2s_k d_k$. Thus, the compute-memory ratio is $h_q s_q \cdot \frac{d_k+d_v}{d_k} \approx 2 h_q s_q$. An NVIDIA H800 SXM5 GPU has a peak memory bandwidth of 3.35 TB/s and peak FLOPs of 990 TFlops. However, due to throttling (reducing to ~1600 MHz in our case), the practical peak FLOPs drops to ~865 TFlops. Therefore, when $h_qs_q \ge \frac{1}{2} \cdot \frac{865}{3.35} = 128$, the kernel is compute-bound; otherwise, it's memory-bound. According to [the overview of DeepSeek's Online Inference System](https://github.com/deepseek-ai/open-infra-index/blob/main/202502OpenSourceWeek/day_6_one_more_thing_deepseekV3R1_inference_system_overview.md), we don't use Tensor Parallel for decoding instances, meaning $h_q$ is 128 and the kernel is compute-bound. Thus, we need to optimize the kernel for compute-bound settings. 
## High-Level Design of the New Kernel To fully utilize GPU compute resources, we need to overlap CUDA Core operations with Tensor Core operations and memory access with computation, keeping the Tensor Core constantly busy. This requires redesigning the kernel's "schedule." [FlashAttention-3's paper](https://arxiv.org/abs/2407.08608) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM. This eliminates the possibility of having two output matrices and letting them use CUDA Core and Tensor Core in an interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation. (You might pause here to ponder - perhaps you can find a better solution than ours!) Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows: 0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0). 1. 
[0] Compute $`\vec p_0 = \vec q K_0^\intercal / qk\_scale`$. 2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / qk\_scale`$. 3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m - m\_new_0)`$ (using the old $m$, so $scale_0 \le 1$). Update $`m \gets m\_new_0`$. 4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - m\_new_0)`$. 5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$. 6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m - m\_new_1)`$. Update $`m \gets m\_new_1`$. 7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - m\_new_1)`$. 8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$. 9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$. 10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$. 11. [0] Update $\vec o_L \gets \vec o_L \cdot scale_1 + \vec p_1 V_{1L}$. Note: We assume one q head for simplicity, so $\vec q$ and $\vec o$ are vectors. Bracketed numbers indicate the warpgroup performing the operation. Assume $\vec o_L$ resides in warpgroup 0's registers and $\vec o_R$ resides in warpgroup 1's registers. This schedule can be viewed as a "ping-pong" variant using one output matrix—we call it "seesaw" scheduling. It's mathematically equivalent to FlashAttention's online softmax algorithm. This schedule allows us to overlap CUDA Core and Tensor Core operations by interleaving the two warpgroups, and also allows us to overlap memory access with computation since we can launch the corresponding Tensor Memory Accelerator (TMA) instructions right after data is no longer needed. The complete schedule is shown below (remember that in MLA, $K$ and $V$ are the same with different names): ![MLA Kernel Sched](assets/MLA%20Kernel%20Sched.drawio.svg) ## Discussion of Technical Details This section covers technical details of the new kernel. 
First, although the kernel targets compute-bound scenarios (where memory bandwidth isn't the bottleneck), we can't ignore memory latency. If the data is not ready when we want to use it, we have to wait. To solve this problem, we employ the following techniques: - **Fine-grained TMA copy - GEMM pipelining:** For a $64 \times 576$ K block, we launch 9 TMA copies (each moving a $64 \times 64$ block). GEMM operations begin as soon as each TMA copy completes (When the first TMA copy is done, we can start the first GEMM operation, and so on), improving memory latency tolerance. - **Cache hints:** Using `cute::TMA::CacheHintSm90::EVICT_FIRST` for TMA copies improves L2 cache hit rates, as shown by experiments. These optimizations achieve up to 80% Tensor Core utilization (of the throttled theoretical peak) and 3 TB/s memory bandwidth on an H800 SXM5 GPU. While slightly slower (~2%) than the old ping-pong buffer version in memory-bound settings, this is acceptable. Other performance improvements include: - **Programmatic Dependent Launch.** We use [programmatic dependent launch](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization) to overlap `splitkv_mla` and `combine` kernels. - **Tile Scheduler.** We implement a tile scheduler to allocate jobs (requests and blocks) to SMs. This ensures a balanced load across SMs. ## Acknowledgements FlashMLA's algorithm and scheduling are inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work. 
## Citation ```bibtex @misc{flashmla2025, title={FlashMLA: Efficient MLA decoding kernels}, author={Jiashi Li, Shengyu Liu}, year={2025}, publisher = {GitHub}, howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}}, } ``` ================================================ FILE: docs/20250929-hopper-fp8-sparse-deep-dive.md ================================================ # A Deep Dive Into The Flash MLA FP8 Decoding Kernel on Hopper With the release of DeepSeek-V3.2, we have doubled the context length of our models from 64K tokens to 128K tokens. This puts significant pressure on GPU memory (a single request with 128K tokens requires a KVCache of size $576 \times 2 \times 62 \times 128 \times 1024 = 8.72\ \mathrm{GiB}$), which can lead to out-of-memory (OOM) errors or under-utilized GPUs due to small batch sizes. To address this, we introduced FP8 KVCache for DeepSeek-V3.2. However, writing a high-performance decoding kernel is challenging due to the need for dequantization and its sparse memory access patterns. In this blog, we share the story behind our new FP8 sparse decoding kernel for Hopper GPUs. We will first explain our FP8 KVCache format, then provide a theoretical analysis of clock cycles, and finally detail the techniques used in our new kernel. ## The FP8 KVCache Format Recall that the decoding phase of the Multi-head Latent Attention (MLA) algorithm operates similarly to Multi-Query Attention (MQA), with 128 query heads and 1 key head, where `head_dim_k = 576` and `head_dim_v = 512` respectively. To reduce the size of the KVCache while maintaining accuracy, we use a fine-grained quantization method. Specifically, we apply tile-level quantization (with a tile size of $1 \times 128$) to the first 512 elements in each token's KV Cache. This results in 512 `float8_e4m3` values and 4 `float32` scale factors. For the remaining 64 elements (the RoPE part), we do not apply quantization as they are sensitive to precision loss. 
Therefore, in GPU memory, each token's KVCache occupies 656 bytes, consisting of 512 `float8_e4m3`s, 4 `float32`s, and 64 `bfloat16`s. Inside the kernel, we first dequantize the 512 `float8_e4m3` values into 512 `bfloat16`s. We then concatenate them with the 64 original `bfloat16` values from the RoPE part. Finally, we perform the MQA calculation using matrix multiply-accumulate (MMA) operations in `bfloat16` precision (i.e., the inputs to the MMAs are in `bfloat16` and the outputs are in `float32`. This applies to both the QK gemm and the attention-score-V gemm). ## Theoretical Analysis of Clock Cycles The main challenge is that Tensor Cores (which handle MMA calculations) are extremely fast, while the dequantization process, performed on CUDA Cores, struggles to keep up. The basic unit on an NVIDIA GPU is the Streaming Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). In our kernel, each CTA runs on one SM, and each SM is only mapped to one CTA. If we assign each CTA (CUDA Thread Block) to process 64 query heads, it only requires $64 \times (576+512) \times 2 / 4096 \approx 34$ cycles for MMA operations per K/V token. However, because the H800 cannot directly cast `float8_e4m3` to `bfloat16`, dequantizing the KVCache for one token requires the following steps: 1. Convert `float8_e4m3` to `half` 2. Convert `half` to `float32` 3. Convert `float32` to `bfloat16` 4. Multiply the converted `bfloat16` value by the `float32` scale factor According to [NVIDIA's documentation](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#throughput-of-native-arithmetic-instructions), we need at least $(\frac{1}{64} + \frac{1}{64} + \frac{1}{16} + \frac{1}{256}) \times 512 \approx 50$ cycles for dequantizing each token! 
This is significantly more than the 34 cycles required for the MMA operations, meaning the kernel is **dequantization-bound**. If left unaddressed, dequantization would become the performance bottleneck, leaving the powerful Tensor Cores underutilized. ## Crossover Before we continue, it's important to note a key fact: every query head within the same query token attends to the same key heads, because this is Multi-Query Attention (MQA). Recall that each CTA processes 64 query heads, while DeepSeek-V3.2 has a total of 128 query heads. If we can find a way to "share" the dequantized K/V values between two CTAs that are processing different sets of query heads, then each CTA would only need to dequantize **half** of the KV cache – which is fantastic! We call this method "crossover", since the idea was actually inspired by [Chromosomal crossover](https://en.wikipedia.org/wiki/Chromosomal_crossover) during [Meiosis](https://en.wikipedia.org/wiki/Meiosis). The next question is, how do we implement this in CUDA? Before NVIDIA's Hopper architecture, the only options for data exchange between CTAs were global memory or the L2 cache, which are slow. However, the powerful Distributed Shared Memory gave us a new solution. ## Distributed Shared Memory to the Rescue Distributed Shared Memory (DSM) is a new feature introduced with the Hopper architecture, alongside the CTA Cluster (thread block cluster). CTAs within the same cluster can directly access each other's shared memory. For more details, you can refer to [NVIDIA Hopper Architecture In-Depth](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/). Here is how we use it: We launch CTAs in clusters of size 2. Each CTA within a cluster is responsible for 64 query heads from the same query token. Each CTA performs the following steps: 1. Loads *half* of the quantized K/V from global memory. We use a wide `__ldg` load with a width of 128 bits to improve performance. 2. 
Dequantizes its assigned half on the CUDA Cores. 3. Stores the dequantized K/V into its own shared memory. 4. Simultaneously uses `st.async` to write the dequantized K/V into the shared memory of the other CTA in the cluster. For synchronization between these operations, we rely on the [cluster transaction barrier](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/), another powerful programming primitive available in CTA Clusters. After the data exchange is complete, each CTA has the *full* set of dequantized K and V values available in its own shared memory, which it can then use to perform the MMA operations. ## Performance Using these techniques, we achieved 410 TFLOPS in a compute-bound configuration (batch_size=128, num_heads=128, s_q=2, topk=2048) on H800 SXM5 GPUs. This is a significant improvement over the 250 TFLOPS achieved by our previous FP8 sparse decoding kernel without the crossover technique. Although this number is still below the 640 TFLOPS peak of our previous bfloat16 dense decoding kernel, one reason is that it's a **sparse** kernel, and its topk is only 2048. With a smaller topk, the relative overhead of the kernel's prologue and epilogue becomes larger compared with dense decoding with long context length. If we set topk to a larger value, such as 32768, this kernel can achieve up to 460 TFLOPS. From another perspective, the execution time of this kernel in the configuration mentioned above is comparable to that of the dense decoding kernel when the sequence length is around 3000. When the sequence length exceeds 3000, the performance advantage of our new kernel becomes even more significant. This also highlights the effectiveness of our DeepSeek Sparse Attention algorithm. 
================================================ FILE: flash_mla/__init__.py ================================================ __version__ = "1.0.0" from flash_mla.flash_mla_interface import ( get_mla_metadata, flash_mla_with_kvcache, flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func, flash_mla_sparse_fwd ) __all__ = [ "get_mla_metadata", "flash_mla_with_kvcache", "flash_attn_varlen_func", "flash_attn_varlen_qkvpacked_func", "flash_attn_varlen_kvpacked_func", "flash_mla_sparse_fwd" ] ================================================ FILE: flash_mla/flash_mla_interface.py ================================================ from typing import Optional, Tuple import dataclasses import torch import flash_mla.cuda as flash_mla_cuda @dataclasses.dataclass class FlashMLASchedMeta: """ A class that stores the tile scheduler metadata of FlashMLA """ @dataclasses.dataclass class Config: b: int s_q: int h_q: int page_block_size: int h_k: int causal: bool is_fp8_kvcache: bool topk: Optional[int] extra_page_block_size: Optional[int] extra_topk: Optional[int] have_initialized: bool = False config: Optional[Config] = None tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. def get_mla_metadata( *args, **kwargs ) -> Tuple[FlashMLASchedMeta, None]: """ Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. Arguments: This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. Return: A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. 
""" return FlashMLASchedMeta(), None def flash_mla_with_kvcache( q: torch.Tensor, k_cache: torch.Tensor, block_table: Optional[torch.Tensor], cache_seqlens: Optional[torch.Tensor], head_dim_v: int, tile_scheduler_metadata: FlashMLASchedMeta, num_splits: None = None, softmax_scale: Optional[float] = None, causal: bool = False, is_fp8_kvcache: bool = False, indices: Optional[torch.Tensor] = None, attn_sink: Optional[torch.Tensor] = None, extra_k_cache: Optional[torch.Tensor] = None, extra_indices_in_kvcache: Optional[torch.Tensor] = None, topk_length: Optional[torch.Tensor] = None, extra_topk_length: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: q: (batch_size, seq_len_q, num_heads_q, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. head_dim_v: Head_dim of v. Must be 512 sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. num_splits_placeholder: must be "None" (to be compatible with the old interface). softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). causal: bool. Whether to apply causal attention mask. 
Only valid for dense attention is_fp8_kvcache: bool. indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), where t is the k-th token of the j-th q-sequence in the i-th batch. attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: head_dim should be 576 while head_dim_v should be 512. In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
""" sched_meta = tile_scheduler_metadata indices_in_kvcache = indices assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" assert num_splits is None, "num_splits must be None" topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if not sched_meta.have_initialized: # Sanity check. We only perform sanity check during the first invocation to save CPU time. if indices_in_kvcache is not None: assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" # Initialize the tile scheduler metadata during the first invocation. sched_meta.have_initialized = True sched_meta.config = FlashMLASchedMeta.Config( q.shape[0], q.shape[1], q.shape[2], k_cache.shape[1], k_cache.shape[2], causal, is_fp8_kvcache, topk, extra_k_page_block_size, extra_topk, ) else: # Check whether the input arguments are consistent with sched_meta helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." assert sched_meta.config is not None assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." 
+ helper_msg assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg if topk is not None: # Sparse attention assert not causal, "causal must be False when sparse attention is enabled" assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( q, k_cache, indices_in_kvcache, topk_length, attn_sink, sched_meta.tile_scheduler_metadata, sched_meta.num_splits, extra_k_cache, extra_indices_in_kvcache, extra_topk_length, head_dim_v, softmax_scale ) else: # Dense attention assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." 
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, causal, sched_meta.tile_scheduler_metadata, sched_meta.num_splits ) sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata sched_meta.num_splits = new_num_splits return (out, lse) def flash_mla_sparse_fwd( q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int = 512, attn_sink: Optional[torch.Tensor] = None, topk_length: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Sparse attention prefill kernel Args: q: [s_q, h_q, d_qk], bfloat16 kv: [s_kv, h_kv, d_qk], bfloat16 indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv sm_scale: float d_v: The dimension of value vectors. Can only be 512 attn_sink: optional, [h_q], float32. If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). This argument has no effect on lse and max_logits. topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. Returns: (output, max_logits, lse) Please refer to tests/ref.py for the precise definitions of these parameters. 
- output: [s_q, h_q, d_v], bfloat16 - max_logits: [s_q, h_q], float - lse: [s_q, h_q], float, log-sum-exp of attention scores """ results = flash_mla_cuda.sparse_prefill_fwd( q, kv, indices, sm_scale, d_v, attn_sink, topk_length ) return results def _flash_attn_varlen_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: qo_total_len, num_qo_heads, head_dim_qk = q.shape kv_total_len, num_kv_heads, head_dim_vo = v.shape mask_mode_code = 1 if causal else 0 if softmax_scale is None: softmax_scale = head_dim_qk ** (-0.5) if out is None: out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) if lse is None: # Make lse contiguous on seqlen dim lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) flash_mla_cuda.dense_prefill_fwd( workspace_buffer, q, k, v, cu_seqlens_qo, cu_seqlens_kv, out, lse, mask_mode_code, softmax_scale, max_seqlen_qo, max_seqlen_kv, is_varlen, ) return out, lse def _flash_attn_varlen_backward( do: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, lse: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, dq: Optional[torch.Tensor] = None, dk: Optional[torch.Tensor] = None, dv: Optional[torch.Tensor] = None, causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: qo_total_len, num_qo_heads, head_dim_qk = q.shape kv_total_len, num_kv_heads, head_dim_vo = v.shape # TODO: fix bwd GQA if num_qo_heads != num_kv_heads: raise 
ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") mask_mode_code = 1 if causal else 0 if softmax_scale is None: softmax_scale = head_dim_qk ** (-0.5) if dq is None: dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) if dk is None: dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) if dv is None: dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 bs = cu_seqlens_qo.shape[0] - 1 workspace_bytes = 0 workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse if num_qo_heads != num_kv_heads: workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) flash_mla_cuda.dense_prefill_bwd( workspace_buffer, do, q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv, dq, dk, dv, mask_mode_code, softmax_scale, max_seqlen_qo, max_seqlen_kv, is_varlen, ) return dq, dk, dv class FlashAttnVarlenFunc(torch.autograd.Function): def forward( ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: out, lse = _flash_attn_varlen_forward( q, k, v, cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal=causal, softmax_scale=softmax_scale, is_varlen=is_varlen, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) ctx.max_seqlen_qo = max_seqlen_qo ctx.max_seqlen_kv = max_seqlen_kv ctx.causal = causal ctx.softmax_scale = softmax_scale ctx.is_varlen = is_varlen return out, lse def backward( 
ctx, do: torch.Tensor, dlse: torch.Tensor, ): del dlse # LSE doesn't support backward currently q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors dq, dk, dv = _flash_attn_varlen_backward( do, q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, causal=ctx.causal, softmax_scale=ctx.softmax_scale, is_varlen=ctx.is_varlen, ) return dq, dk, dv, None, None, None, None, None, None, None def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, deterministic: bool = False, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: assert dropout_p == 0.0 assert not deterministic return FlashAttnVarlenFunc.apply( q, k, v, cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal, softmax_scale, is_varlen, ) def flash_attn_varlen_qkvpacked_func( qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, head_dim_qk: int, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, deterministic: bool = False, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: assert dropout_p == 0.0 assert not deterministic return FlashAttnVarlenFunc.apply( qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal, softmax_scale, is_varlen, ) def flash_attn_varlen_kvpacked_func( q: torch.Tensor, kv: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, head_dim_qk: int, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, deterministic: bool = False, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: assert dropout_p == 0.0 assert not deterministic return FlashAttnVarlenFunc.apply( q, kv[:, :, 
:head_dim_qk], kv[:, :, head_dim_qk:], cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal, softmax_scale, is_varlen, ) ================================================ FILE: setup.py ================================================ import os from pathlib import Path from datetime import datetime import subprocess from setuptools import setup, find_packages from torch.utils.cpp_extension import ( BuildExtension, CUDAExtension, IS_WINDOWS, CUDA_HOME ) def is_flag_set(flag: str) -> bool: return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"] def get_features_args(): features_args = [] if is_flag_set("FLASH_MLA_DISABLE_FP16"): features_args.append("-DFLASH_MLA_DISABLE_FP16") return features_args def get_arch_flags(): # Check NVCC Version # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py` assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support" nvcc_version = subprocess.check_output( [os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT ).decode('utf-8') nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip() major, minor = map(int, nvcc_version_number.split('.')) print(f'Compiling using NVCC {major}.{minor}') DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") if major < 12 or (major == 12 and minor <= 8): assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." 
# TODO Implement this arch_flags = [] if not DISABLE_SM100: arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"]) if not DISABLE_SM90: arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"]) return arch_flags def get_nvcc_thread_args(): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return ["--threads", nvcc_threads] subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) this_dir = os.path.dirname(os.path.abspath(__file__)) if IS_WINDOWS: cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"] else: cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"] ext_modules = [] ext_modules.append( CUDAExtension( name="flash_mla.cuda", sources=[ # API "csrc/api/api.cpp", # Misc kernels for decoding "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu", "csrc/smxx/decode/combine/combine.cu", # sm90 dense decode "csrc/sm90/decode/dense/instantiations/fp16.cu", "csrc/sm90/decode/dense/instantiations/bf16.cu", # sm90 sparse decode "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu", "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu", "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu", "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu", # sm90 sparse prefill "csrc/sm90/prefill/sparse/fwd.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu", "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu", # sm100 dense prefill & backward "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", # sm100 sparse prefill "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu", "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu", "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu", 
"csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu", "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu", # sm100 sparse decode "csrc/sm100/decode/head64/instantiations/v32.cu", "csrc/sm100/decode/head64/instantiations/model1.cu", "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), "nvcc": [ "-O3", "-std=c++20", "-DNDEBUG", "-D_USE_MATH_DEFINES", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", "--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use", "-lineinfo", "--source-in-ptx", ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(), }, include_dirs=[ Path(this_dir) / "csrc", Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", ], ) ) try: cmd = ['git', 'rev-parse', '--short', 'HEAD'] rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() except Exception as _: now = datetime.now() date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S") rev = '+' + date_time_str setup( name="flash_mla", version="1.0.0" + rev, packages=find_packages(include=['flash_mla']), ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, ) ================================================ FILE: tests/kernelkit/.gitignore ================================================ build *.so *.egg-info/ __pycache__/ dist/ /.vscode .cache /temp /profiles ================================================ FILE: tests/kernelkit/__init__.py ================================================ from . 
import bench from . import compare from . import generate from . import precision from . import utils from .bench import bench_kineto, bench_by_cuda_events from .compare import get_cos_diff, check_is_bitwise_equal, check_is_allclose, check_is_bitwise_equal_comparator, check_is_allclose_comparator from .generate import gen_non_contiguous_randn_tensor, gen_non_contiguous_tensor, non_contiguousify from .precision import LowPrecisionMode, is_low_precision_mode, optional_cast_to_bf16_and_cast_back from .utils import colors, cdiv, is_using_profiling_tools, set_random_seed, Counter ================================================ FILE: tests/kernelkit/bench.py ================================================ from typing import Tuple, List, Callable, Union, Dict, overload import dataclasses import torch import triton from .utils import is_using_profiling_tools class empty_suppress: def __enter__(self): return self def __exit__(self, *_): pass @triton.jit def profiler_range_start_marker_kernel(): pass def _run_profiler_range_start_marker_kernel(): profiler_range_start_marker_kernel[(1,)]() @dataclasses.dataclass class BenchKinetoRawResult: """ A struct holding the result of `bench_kineto` """ is_using_nsys: bool num_tests: int time_ranges: Dict[str, List[Tuple[float, float]]] def _get_matched_kernel_name(self, name_substr: str, allow_no_match: bool = False, allow_multiple_match: bool = False) -> List[str]: matched_names = [name for name in self.time_ranges.keys() if name_substr in name] if not allow_no_match and len(matched_names) == 0: all_kernel_names_str = '\n - ' + '\n - '.join(self.time_ranges.keys()) raise ValueError(f"Error: No kernel name matched for substring {name_substr}.\nAvailable kernels are: {all_kernel_names_str}") if not allow_multiple_match and len(matched_names) > 1: raise ValueError(f"Error: Multiple kernel matched for substring {name_substr}: {', '.join(matched_names)}") return matched_names def get_kernel_names(self) -> List[str]: return 
list(self.time_ranges.keys()) def get_kernel_times(self, kernel_names_substr: List[str], allow_indivisible_run_count: bool = False, allow_missing: bool = False, allow_multiple_match: bool = False, return_avg_individual_run: bool = False) -> List[float]: """ Get the average each-run time usage of each kernel provided in `kernel_names` If return_avg_individual_run is False, return sum(time) / num_tests, else return sum(time) / len(time) If is_using_profiling_tools (which is conflict with bench_kineto), return a series of 1 seconds """ if is_using_profiling_tools(): return [1 for _ in range(len(kernel_names_substr))] result = [] for substr in kernel_names_substr: matched_names = self._get_matched_kernel_name(substr, allow_no_match=allow_missing, allow_multiple_match=allow_multiple_match) if len(matched_names) == 0: assert allow_missing result.append(0) else: time_usage_sum = 0 run_cnt_sum = 0 for matched_name in matched_names: run_cnt = len(self.time_ranges[matched_name]) if not allow_indivisible_run_count and run_cnt % self.num_tests != 0: raise RuntimeError(f"Error: the number of runs for kernel {matched_name} ({run_cnt}) is indivisible by `num_tests` ({self.num_tests})") time_usage_sum += sum([end-start for (start, end) in self.time_ranges[matched_name]]) run_cnt_sum += run_cnt denominator = run_cnt_sum if return_avg_individual_run else self.num_tests result.append(time_usage_sum / denominator) return result def get_kernel_time(self, kernel_name_substr: str) -> float: return self.get_kernel_times([kernel_name_substr])[0] def get_e2e_time(self, start_kernel_name_substr: str, end_kenrel_name_substr: str) -> float: """ Get the end-to-end time usage for a sequence of kernels defined as "last kernel end time" - "first kernel start time" If is_using_profiling_tools (which is conflict with bench_kineto), return 1 second """ if is_using_profiling_tools(): return 1 start_kernel_name = self._get_matched_kernel_name(start_kernel_name_substr)[0] end_kernel_name = 
self._get_matched_kernel_name(end_kenrel_name_substr)[0] num_start_kernels = len(self.time_ranges[start_kernel_name]) num_end_kernels = len(self.time_ranges[end_kernel_name]) if num_start_kernels%self.num_tests != 0: raise RuntimeError(f"Error: the number of runs for kernel {start_kernel_name} ({num_start_kernels}) is indivisible by `num_tests` ({self.num_tests})") if num_end_kernels%self.num_tests != 0: raise RuntimeError(f"Error: the number of runs for kernel {end_kernel_name} ({num_end_kernels}) is indivisible by `num_tests` ({self.num_tests})") time_spans = [] for i in range(self.num_tests): end_time = self.time_ranges[end_kernel_name][(i+1)*(num_end_kernels//self.num_tests)-1][1] start_time = self.time_ranges[start_kernel_name][i*(num_start_kernels//self.num_tests)][0] time_spans.append((start_time, end_time)) result = sum([end-start for (start, end) in time_spans]) / self.num_tests return result def bench_kineto(fn: Callable, num_tests: int = 30, flush_l2: bool = True) -> BenchKinetoRawResult: """ Run `fn` for `num_tests` times under `bench_kineto` (CUPTI), and returns a BenchKinetoRawResult """ using_nsys = is_using_profiling_tools() # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle flush_l2_size = int(8e9 // 4) schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() with profiler: for i in range(2): if i == 1 and not using_nsys: _run_profiler_range_start_marker_kernel() # This marks the start of the profiling range for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() enable_nvtx_range = i == 1 and _ == num_tests-1 if enable_nvtx_range: torch.cuda.nvtx.range_push("profile_target") fn() if enable_nvtx_range: torch.cuda.nvtx.range_pop() if not using_nsys: if i 
== 0: torch.cuda.synchronize() profiler.step() if using_nsys: return BenchKinetoRawResult(True, num_tests, {}) from torch.autograd.profiler_util import EventList, FunctionEvent # pylint: disable=import-outside-toplevel events: EventList = profiler.events() # type: ignore # Filter out all events that are not function events events: List[FunctionEvent] = [event for event in events if isinstance(event, FunctionEvent)] # Filter out all events before the range marker for idx, event in enumerate(events): if event.name == "profiler_range_start_marker_kernel": events = events[idx+1:] break else: raise RuntimeError("Could not find profiler range start marker kernel event") # Get time ranges of each kernel kernel_times = {} for event in events: kernel_name = event.name if kernel_name not in kernel_times: kernel_times[kernel_name] = [] kernel_times[kernel_name].append((event.time_range.start/1e6, event.time_range.end/1e6)) return BenchKinetoRawResult(False, num_tests, kernel_times) @overload def bench_by_cuda_events(kernels: List[Callable], num_warmups_each: int, num_runs_each: int) -> List[float]: ... @overload def bench_by_cuda_events(kernels: Callable, num_warmups_each: int, num_runs_each: int) -> float: ... 
def bench_by_cuda_events(kernels: Union[List[Callable], Callable], num_warmups_each: int, num_runs_each: int) -> Union[List[float], float]: buf_for_l2_clear = torch.empty(int(256e6//4), dtype=torch.int32, device='cuda') is_kernel_single_callable = isinstance(kernels, Callable) if is_kernel_single_callable: kernels = [kernels] torch.cuda.synchronize() for i in range(num_warmups_each): for kernel in kernels: kernel() if i == 0: # Ensure the first run is successful try: torch.cuda.synchronize() except Exception as e: print(f"Kernel {kernel.__name__} failed on warmup run {i}: {e}") return [] start_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels] end_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels] for i in range(num_runs_each): for j, kernel in enumerate(kernels): buf_for_l2_clear.random_() if i == num_runs_each-1: torch.cuda.nvtx.range_push("profile_target") start_events[j][i].record() kernel() end_events[j][i].record() if i == num_runs_each-1: torch.cuda.nvtx.range_pop() torch.cuda.synchronize() time_usages = [ sum([start_events[j][i].elapsed_time(end_events[j][i])*1e-3 for i in range(num_runs_each)]) / num_runs_each for j in range(len(kernels)) ] if is_kernel_single_callable: time_usages = time_usages[0] return time_usages ================================================ FILE: tests/kernelkit/compare.py ================================================ from typing import List import torch def check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tensor, result: torch.Tensor): """ Return if two tensors are bitwise equal Return a bool if avoid_sync is False, else return a tensor """ assert ans.shape == ref.shape, "Shape mismatch" torch.all(torch.eq(ans, ref), out=result) def check_is_bitwise_equal(name: str, ans: torch.Tensor, ref: torch.Tensor, quiet: bool = False) -> bool: is_bitwise_equal = torch.equal(ans, ref) if not quiet and not is_bitwise_equal: 
print(f"`{name}` mismatch: not bitwise equal. Mismatch count: {(ans != ref).sum().item()} out of {ans.numel()}") return is_bitwise_equal def get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float: """ Calculate the cosine diff between two tensors Return a float if avoid_sync is False, else return a tensor """ ans, ref = ans.double(), ref.double() if (ref*ref).sum().item() < 1e-12: return 0 denominator = (ans*ans + ref*ref).sum().item() sim = 2 * (ans*ref).sum().item() / denominator return 1 - sim def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7, quiet: bool = False) -> bool: """ Check if two tensors are close enough Return a bool if avoid_sync is False, else return a tensor """ assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}" assert ans.dtype == ref.dtype, f"`{name}` Dtype mismatch: {ans.dtype} vs {ref.dtype}" ans = ans.clone().to(torch.float) ref = ref.clone().to(torch.float) def report_err(*args, **kwargs): if not quiet: print(*args, **kwargs) # Deal with anomalies def deal_with_anomalies(val: float): ref_mask = (ref == val) if (val == val) else (ref != ref) ans_mask = (ans == val) if (val == val) else (ans != ans) ref[ref_mask] = 0.0 ans[ans_mask] = 0.0 if not torch.equal(ref_mask, ans_mask): report_err(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref") return False return True anomalies_check_passed = True anomalies_check_passed &= deal_with_anomalies(float("inf")) anomalies_check_passed &= deal_with_anomalies(float("-inf")) anomalies_check_passed &= deal_with_anomalies(float("nan")) cos_diff = get_cos_diff(ans, ref) raw_abs_err = torch.abs(ans-ref) raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6)) rel_err = raw_rel_err.masked_fill(raw_abs_err List[int]: result = [] for size in t.shape[::-1]: result.append(pos % size) pos = pos // size assert pos == 0 return 
result[::-1] report_err(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}") report_err(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}") report_err(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)") report_err(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") return False else: if abs(cos_diff) > cos_diff_tol: report_err(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})") return False return True def check_is_allclose_comparator(name: str, ans: torch.Tensor, ref: torch.Tensor, out: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7): out.fill_(check_is_allclose(name, ans, ref, abs_tol, rel_tol, cos_diff_tol)) ================================================ FILE: tests/kernelkit/generate.py ================================================ import torch def _get_new_non_contiguous_tensor_shape(shape): """ Get the expanded shape for a non-contiguous tensor. 
The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1 """ return [dim+128 if dim_idx == len(shape)-1 else dim+1 for dim_idx, dim in enumerate(shape)] def gen_non_contiguous_randn_tensor(shape, *args, **kwargs): new_shape = _get_new_non_contiguous_tensor_shape(shape) base_tensor = torch.randn(new_shape, *args, **kwargs) slices = [slice(0, dim) for dim in shape] return base_tensor[slices] def gen_non_contiguous_tensor(shape, *args, **kwargs): new_shape = _get_new_non_contiguous_tensor_shape(shape) base_tensor = torch.empty(new_shape, *args, **kwargs) slices = [slice(0, dim) for dim in shape] return base_tensor[slices] def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor: new_tensor = gen_non_contiguous_tensor(tensor.shape, dtype=tensor.dtype, device=tensor.device) new_tensor[:] = tensor return new_tensor ================================================ FILE: tests/kernelkit/precision.py ================================================ import torch _is_low_precision_mode_stack = [] class LowPrecisionMode: def __init__(self, enabled: bool = True): self.enabled = enabled def __enter__(self): global _is_low_precision_mode_stack _is_low_precision_mode_stack.append(self.enabled) def __exit__(self, exc_type, exc_value, traceback): global _is_low_precision_mode_stack _is_low_precision_mode_stack.pop() def is_low_precision_mode() -> bool: global _is_low_precision_mode_stack if len(_is_low_precision_mode_stack) == 0: return False return _is_low_precision_mode_stack[-1] def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor: assert tensor.dtype == torch.float32, "Input tensor must be of dtype torch.float32 for optional casting." 
if is_low_precision_mode(): tensor_bf16 = tensor.to(torch.bfloat16) tensor_fp32 = tensor_bf16.to(torch.float32) return tensor_fp32 else: return tensor ================================================ FILE: tests/kernelkit/utils.py ================================================ import os import functools colors = { 'RED_FG': '\033[31m', 'GREEN_FG': '\033[32m', 'CYAN_FG': '\033[36m', 'GRAY_FG': '\033[90m', 'YELLOW_FG': '\033[33m', 'RED_BG': '\033[41m', 'GREEN_BG': '\033[42m', 'CYAN_BG': '\033[46m', 'YELLOW_BG': '\033[43m', 'GRAY_BG': '\033[100m', 'CLEAR': '\033[0m' } def cdiv(a: int, b: int) -> int: return (a + b - 1) // b @functools.lru_cache() def is_using_profiling_tools() -> bool: """ Return whether we are running under profiling tools like nsys or ncu NOTE cuda-gdb will also cause conflict with CUPTI (bench_kineto) but currently we lack ways to detect it """ is_using_nsys = os.environ.get('NSYS_PROFILING_SESSION_ID') is not None is_using_ncu = os.environ.get('NV_COMPUTE_PROFILER_PERFWORKS_DIR') is not None is_using_compute_sanitizer = os.environ.get('NV_SANITIZER_INJECTION_PORT_RANGE_BEGIN') is not None return is_using_nsys or is_using_ncu or is_using_compute_sanitizer def set_random_seed(seed: int): import random import numpy as np import torch random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) class Counter: def __init__(self): self.count = 0 def next(self) -> int: self.count += 1 return self.count - 1 ================================================ FILE: tests/lib.py ================================================ import dataclasses import os import enum from typing import List, Optional import random import torch import kernelkit as kk import flash_mla import quant class TestTarget(enum.Enum): FWD = 0 DECODE = 1 @dataclasses.dataclass class ExtraTestParamForDecode: b: int is_varlen: bool have_zero_seqlen_k: bool extra_s_k: Optional[int] = None extra_topk: Optional[int] = None 
block_size: int = 64 extra_block_size: Optional[int] = None have_extra_topk_length: bool = False @dataclasses.dataclass class TestParam: s_q: int s_kv: int topk: int h_q: int = 128 h_kv: int = 1 d_qk: int = 512 d_v: int = 512 seed: int = -1 # -1: to be filled automatically check_correctness: bool = True is_all_indices_invalid: bool = False # All indices are invalid, i.e., all indices are set to a large number (e.g., 2147483647) num_runs: int = 10 have_attn_sink: bool = False have_topk_length: bool = False decode: Optional[ExtraTestParamForDecode] = None @dataclasses.dataclass class RawTestParamForDecode: """ "Flattened" test parameters for decoding test In our test script, to maintain compatibility with TestParam, we embed decode-only parameters into TestParam.decode, which is not very convinient when construct testcases. So here we have a "flattened" version of test parameters for decoding test. """ b: int h_q: int s_q: int h_kv: int s_kv: int is_varlen: bool topk: int is_all_indices_invalid: bool = False have_zero_seqlen_k: bool = False have_topk_length: bool = False enable_attn_sink: bool = True extra_s_k: Optional[int] = None extra_topk: Optional[int] = None block_size: int = 64 extra_block_size: Optional[int] = None have_extra_topk_length: bool = False d_qk: int = 576 # Q/K head dim (= dv + RoPE dim) d_v: int = 512 # V head dim check_correctness: bool = True num_runs: int = 10 seed: int = -1 def to_test_param(self) -> TestParam: return TestParam( self.s_q, self.s_kv, self.topk, self.h_q, self.h_kv, self.d_qk, self.d_v, self.seed, self.check_correctness, self.is_all_indices_invalid, self.num_runs, self.enable_attn_sink, self.have_topk_length, decode = ExtraTestParamForDecode( self.b, self.is_varlen, self.have_zero_seqlen_k, self.extra_s_k, self.extra_topk, self.block_size, self.extra_block_size, self.have_extra_topk_length ) ) @dataclasses.dataclass class Testcase: p: TestParam dOut: torch.Tensor # [s_q, h_q, d_v] q: torch.Tensor # [s_q, h_q, d_qk] kv: 
torch.Tensor # [s_kv, h_kv, d_qk] indices: torch.Tensor # [s_q, h_kv, topk] sm_scale: float attn_sink: Optional[torch.Tensor] # [h_q] topk_length: Optional[torch.Tensor] # [s_q] def _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int]) -> torch.Tensor: """ Generate random permutations in batch The return tensor, denoted as `res`, has a shape of [batch_size, perm_size]. `0 <= res[i, :] < perm_range[i]` holds. Values within each row are unique. If, for some `i`, `perm_range[i] < perm_size` holds, then `res[i, :]` contains values in `[0, perm_range[i])` as many as possible, and the rest are filled with `padding`. """ assert not torch.are_deterministic_algorithms_enabled() torch.use_deterministic_algorithms(True) perm_range_max = max(int(torch.max(perm_range).item()), perm_size) rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32) rand[torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) >= perm_range.view(batch_size, 1)] = float("-inf") # Fill invalid positions, so that the following `topk` operators will select positions within `perm_range` first res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32) if len(paddings) == 1: res[res >= perm_range.view(batch_size, 1)] = paddings[0] else: fillers = torch.tensor(paddings, dtype=torch.int32).index_select(0, torch.randint(0, len(paddings), (res.numel(), ), dtype=torch.int32)) res.masked_scatter_(res >= perm_range.view(batch_size, 1), fillers) torch.use_deterministic_algorithms(False) return res def generate_testcase(t: TestParam) -> Testcase: kk.set_random_seed(t.seed) q = torch.randn((t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10 kv = torch.randn((t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10 do = torch.randn((t.s_q, t.h_q, t.d_v), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10 q.clamp_(-10, 10) kv.clamp_(-10, 10) do.clamp_(-10, 10) invalid_indices_candidate = 
[-2147483648, -123456, -1, t.s_kv, 114514, 1919810, 2147480000, 2147483647] indices = _randperm_batch(t.s_q, torch.full((t.s_q, ), t.s_kv, dtype=torch.int32), t.topk, invalid_indices_candidate).view(t.s_q, t.h_kv, t.topk) if t.is_all_indices_invalid: all_indices_invalid_mask = torch.randn(t.s_q, device='cpu') < -2 indices[all_indices_invalid_mask[:, None, None].broadcast_to(indices.shape)] = random.choice(invalid_indices_candidate) indices = indices.to(q.device) attn_sink = None if t.have_attn_sink: attn_sink = torch.randn((t.h_q, ), dtype=torch.float32) mask = torch.randn((t.h_q, ), dtype=torch.float32) attn_sink[mask < -0.5] = float("-inf") attn_sink[mask > +0.5] = float("+inf") topk_length = None if t.have_topk_length: topk_length = torch.randint(0, max(t.topk + 1, 64), (t.s_q, ), dtype=torch.int32, device=q.device).clamp_max(t.topk) q = kk.non_contiguousify(q) kv = kk.non_contiguousify(kv) do = kk.non_contiguousify(do) indices = kk.non_contiguousify(indices) return Testcase( p=t, dOut=do, q=q, kv=kv, indices=indices, sm_scale=0.5, # Otherwise dK is too small compared to dV attn_sink=attn_sink, topk_length=topk_length ) @dataclasses.dataclass class KVScope: t: TestParam cache_seqlens: torch.Tensor block_table: torch.Tensor blocked_k: torch.Tensor abs_indices: torch.Tensor indices_in_kvcache: torch.Tensor topk_length: Optional[torch.Tensor] blocked_k_quantized: Optional[torch.Tensor] = None def quant_and_dequant_(self): """ For FP8 cases, we need to quantize the KV cache for Flash MLA. 
Besides, the quantization error may be too large to be distinguished from wrong kernels, so we de-quantize kvcache here to mitigate quantization error """ fp8_kvcache_layout = None if self.t.d_qk == 576: fp8_kvcache_layout = quant.FP8KVCacheLayout.V32_FP8Sparse elif self.t.d_qk == 512: assert self.abs_indices is not None fp8_kvcache_layout = quant.FP8KVCacheLayout.MODEL1_FP8Sparse else: assert False self.blocked_k_quantized = quant.quantize_k_cache(self.blocked_k, fp8_kvcache_layout) blocked_k_dequantized = quant.dequantize_k_cache(self.blocked_k_quantized, fp8_kvcache_layout) self.blocked_k = blocked_k_dequantized def get_kvcache_for_flash_mla(self) -> torch.Tensor: """ Return the quantized blocked_k for Flash MLA """ assert self.blocked_k_quantized is not None, "Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`" return self.blocked_k_quantized def apply_perm(self, perm: torch.Tensor) -> "KVScope": """ Apply a batch permutation to this KVScope. Used for batch-invariance test """ new_kvscope = KVScope( self.t, self.cache_seqlens[perm], self.block_table[perm], self.blocked_k, self.abs_indices[perm], self.indices_in_kvcache[perm], self.topk_length[perm] if self.topk_length is not None else None, self.blocked_k_quantized ) return new_kvscope @dataclasses.dataclass class TestcaseForDecode: p: TestParam q: torch.Tensor # [b, s_q, h_q, d_qk] attn_sink: Optional[torch.Tensor] # [h_q] sm_scale: float kv_scope: KVScope extra_kv_scope: Optional[KVScope] def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode: kk.set_random_seed(t.seed) assert t.h_q % t.h_kv == 0 assert t.decode is not None q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk)) q.clamp_(min=-1.0, max=1.0) attn_sink = None if t.have_attn_sink: attn_sink = torch.randn((t.h_q, ), dtype=torch.float32) inf_mask = torch.randn((t.h_q, ), dtype=torch.float32) attn_sink[inf_mask > 0.5] = float("inf") attn_sink[inf_mask < -0.5] = float("-inf") def generate_one_k_scope(s_k: int, 
block_size: int, topk: int, is_varlen: bool, have_zero_seqlen: bool, is_all_indices_invalid: bool, have_topk_length: bool) -> KVScope: b = t.decode.b # type: ignore cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device='cpu') if is_varlen: for i in range(b): cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q) if have_zero_seqlen: zeros_mask = torch.randn(b, dtype=torch.float32, device='cpu') > 0 cache_seqlens_cpu[zeros_mask] = 0 max_seqlen_alignment = 4 * block_size max_seqlen_pad = max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) * max_seqlen_alignment cache_seqlens = cache_seqlens_cpu.cuda() assert max_seqlen_pad % block_size == 0 block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1) blocked_k = kk.gen_non_contiguous_randn_tensor((block_table.numel(), block_size, t.h_kv, t.d_qk)) / 10 blocked_k.clamp_(min=-1.0, max=1.0) abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32) if is_all_indices_invalid: abs_indices.fill_(-1) else: abs_indices[:] = _randperm_batch(b*t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1]).view(b, t.s_q, topk) indices_in_kvcache = quant.abs_indices2indices_in_kvcache(abs_indices, block_table, block_size) topk_length = torch.randint(0, topk+1, (b, ), dtype=torch.int32, device=q.device) if have_topk_length else None # Mask nonused KV as NaN if have_topk_length: indices_in_kvcache_masked = indices_in_kvcache.clone() indices_in_kvcache_masked[torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) >= (topk_length.view(b, 1, 1) if have_topk_length else topk)] = -1 else: indices_in_kvcache_masked = indices_in_kvcache blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk) nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') nonused_indices_mask[indices_in_kvcache_masked] = 
False blocked_k[nonused_indices_mask, :, :] = float("nan") blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk) block_table = kk.non_contiguousify(block_table) abs_indices = kk.non_contiguousify(abs_indices) indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache) return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length) kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length) kv_scope0.quant_and_dequant_() if t.decode.extra_topk is not None: if t.decode.extra_s_k is None: t.decode.extra_s_k = t.decode.extra_topk*2 if t.decode.extra_block_size is None: t.decode.extra_block_size = t.decode.block_size kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length) kv_scope1.quant_and_dequant_() else: assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length kv_scope1 = None sm_scale = t.d_qk ** -0.55 q = kk.non_contiguousify(q) return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1) def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool): assert not return_p_sum return flash_mla.flash_mla_sparse_fwd( t.q, t.kv, t.indices, sm_scale=t.sm_scale, attn_sink=t.attn_sink, topk_length=t.topk_length ) def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits): assert p.decode is not None return flash_mla.flash_mla_with_kvcache( t.q, t.kv_scope.get_kvcache_for_flash_mla(), None, None, p.d_v, tile_scheduler_metadata, num_splits, t.sm_scale, False, True, t.kv_scope.indices_in_kvcache, t.attn_sink, t.extra_kv_scope.get_kvcache_for_flash_mla() if t.extra_kv_scope is not None else None, t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not 
None else None, t.kv_scope.topk_length, t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None ) @dataclasses.dataclass class FlopsAndMemVolStatistics: """ FLOPs and memory volume statistics for prefilling """ fwd_flop: float fwd_mem_vol: float def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics: total_topk = (p.s_q*p.topk) if t.topk_length is None else t.topk_length.sum().item() indices_valid_mask = (t.indices >= 0) & (t.indices < p.s_kv) if t.topk_length is not None: indices_valid_mask &= (torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk)) < t.topk_length[:, None, None] num_valid_indices = indices_valid_mask.sum().item() fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v) fwd_mem_vol = num_valid_indices*p.d_qk*2 + p.s_q*p.h_q*(p.d_qk+p.d_v)*2 return FlopsAndMemVolStatistics( fwd_flop, fwd_mem_vol, ) @dataclasses.dataclass class FlopsAndMemVolStatisticsForDecode: """ FLOPs and memory volume statistics for decoding """ flop: float mem_vol: float def count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode) -> FlopsAndMemVolStatisticsForDecode: assert p.decode b = p.decode.b def get_num_attended_tokens(kv_scope: KVScope) -> int: topk = kv_scope.indices_in_kvcache.shape[-1] if kv_scope.topk_length is None: return b * p.s_q * topk else: return int(kv_scope.topk_length.sum().item()) * p.s_q def get_num_retrieved_tokens(kv_scope: KVScope) -> int: if kv_scope.topk_length is None: indices = kv_scope.indices_in_kvcache else: indices = kv_scope.indices_in_kvcache.clone() batch, s_q, topk = indices.shape mask = torch.arange(0, topk, device=indices.device).view(1, 1, topk).broadcast_to(batch, s_q, topk) >= kv_scope.topk_length.view(batch, 1, 1) indices[mask] = -1 num_unique_tokens = indices.unique().numel() # type: ignore return num_unique_tokens num_attended_tokens = get_num_attended_tokens(t.kv_scope) + (get_num_attended_tokens(t.extra_kv_scope) if 
t.extra_kv_scope is not None else 0) num_retrieved_tokens = get_num_retrieved_tokens(t.kv_scope) + (get_num_retrieved_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0) compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v) kv_token_size = 656 if p.d_qk == 576 else 576 # Assume FP8 KV Cache mem_vol = sum([ 2 * b * p.s_q * p.h_q * p.d_qk, # Q num_retrieved_tokens * kv_token_size, # K 2 * b * p.s_q * p.h_q * p.d_v, # O ]) return FlopsAndMemVolStatisticsForDecode( compute_flop, mem_vol ) def is_no_cooldown() -> bool: return os.environ.get('NO_COOLDOWN', '').lower() in ['1', 'yes', 'y'] ================================================ FILE: tests/quant.py ================================================ import enum from typing import Tuple import torch class FP8KVCacheLayout(enum.Enum): V32_FP8Sparse = 1 MODEL1_FP8Sparse = 2 def get_meta(self) -> Tuple[int, int, int, int, int]: # Return: (d, d_nope, d_rope, tile_size, num_tiles) return { FP8KVCacheLayout.V32_FP8Sparse: (576, 512, 64, 128, 4), FP8KVCacheLayout.MODEL1_FP8Sparse: (512, 448, 64, 64, 7) }[self] def _cast_scale_inv_to_ue8m0(scales_inv: torch.Tensor, out_dtype = torch.float32) -> torch.Tensor: return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype) def quantize_k_cache( input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) kvcache_layout: FP8KVCacheLayout, ) -> torch.Tensor: """ Quantize the k-cache For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py """ d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta() assert input_k_cache.shape[-1] == d num_blocks, block_size, h_k, _ = input_k_cache.shape assert h_k == 1 input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] input_elem_size = input_k_cache.element_size() if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse: bytes_per_token = d_nope + num_tiles*4 + input_elem_size*d_rope result = torch.empty((num_blocks, block_size+1, 
bytes_per_token), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size, :] result_k_nope_part = result[..., :d_nope] result_k_scale_factor = result[..., d_nope: d_nope + num_tiles*4].view(torch.float32) result_k_rope_part = result[..., d_nope + num_tiles*4:].view(input_k_cache.dtype) result_k_rope_part[:] = input_k_cache[..., d_nope:] for tile_idx in range(0, num_tiles): cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size] cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv) result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope result = result.view(num_blocks, block_size, 1, -1) return result elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse: bytes_per_token = d_nope + 2*d_rope + num_tiles + 1 size_per_block_padded = (block_size*bytes_per_token + 576-1) // 576 * 576 result = torch.empty((num_blocks, size_per_block_padded), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size*bytes_per_token] result_k_nope_rope_part = result[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope) result_k_nope = result_k_nope_rope_part[:, :, :d_nope] # [num_blocks, block_size, d_nope] result_k_rope = result_k_nope_rope_part[:, :, d_nope:].view(input_k_cache.dtype) # [num_blocks, block_size, d_rope] result_k_scale_factor = result[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu) # [num_blocks, block_size, num_tiles] result_k_rope[:] = input_k_cache[..., d_nope:] for tile_idx in range(0, num_tiles): cur_scale_factors_inv = 
torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size] cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv) result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv.to(torch.float8_e8m0fnu) cur_scale_factors_inv = cur_scale_factors_inv.view(num_blocks, block_size, 1) cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) result_k_nope[:, :, tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope result = result.view(num_blocks, block_size, 1, -1) return result else: raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}") def dequantize_k_cache( quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) kvcache_layout: FP8KVCacheLayout, ) -> torch.Tensor: """ De-quantize the k-cache """ d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta() num_blocks, block_size, h_k, _ = quant_k_cache.shape assert h_k == 1 result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device) if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse: quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) input_nope = quant_k_cache[..., :d_nope] input_scale = quant_k_cache[..., d_nope:d_nope + num_tiles*4].view(torch.float32) input_rope = quant_k_cache[..., d_nope + num_tiles*4:].view(torch.bfloat16) result[..., d_nope:] = input_rope for tile_idx in range(0, num_tiles): cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) cur_scales = input_scale[..., tile_idx].unsqueeze(-1) result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse: quant_k_cache = quant_k_cache.view(num_blocks, -1) # [num_blocks, ...] 
input_nope_rope = quant_k_cache[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope) input_nope = input_nope_rope[:, :, :d_nope] input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16) input_scale = quant_k_cache[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu) # [num_blocks, block_size, num_tiles] result[..., d_nope:] = input_rope for tile_idx in range(0, num_tiles): cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.bfloat16) cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1) result[..., tile_idx*tile_size: (tile_idx+1)*tile_size] = cur_nope * cur_scales else: raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}") result = result.view(num_blocks, block_size, 1, d) return result def abs_indices2indices_in_kvcache( abs_indices: torch.Tensor, # [b, s_q, topk] block_table: torch.Tensor, # [b, /] block_size: int, ) -> torch.Tensor: """ Convert abs_indices (logical index, ranging from 0 to s_k-1) to index expected by the sparse attn kernel Equivalent to: b, s_q, topk = abs_indices.shape indices_in_kvcache = torch.empty_like(abs_indices) for i in range(b): cur_abs_indices = abs_indices[i, :, :].clone() # [s_q, topk] invalid_mask = cur_abs_indices == -1 cur_abs_indices[invalid_mask] = 0 cur_indices_in_kvcache = block_table[i].index_select(0, cur_abs_indices.flatten()//block_size).view(s_q, topk)*block_size + cur_abs_indices%block_size cur_indices_in_kvcache[invalid_mask] = -1 indices_in_kvcache[i] = cur_indices_in_kvcache return indices_in_kvcache """ b, s_q, topk = abs_indices.shape _, max_blocks_per_seq = block_table.shape abs_indices = abs_indices.clone() invalid_mask = abs_indices == -1 abs_indices[invalid_mask] = 0 real_block_idxs = block_table.view(-1).index_select(0, (abs_indices//block_size + torch.arange(0, b).view(b, 1, 1)*max_blocks_per_seq).view(-1)) indices_in_kvcache = real_block_idxs.view(b, s_q, 
topk)*block_size + abs_indices%block_size indices_in_kvcache[invalid_mask] = -1 return indices_in_kvcache ================================================ FILE: tests/ref.py ================================================ from typing import Optional, Tuple import torch from lib import TestParam, Testcase, TestcaseForDecode, KVScope def _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int) -> torch.Tensor: if lse1 is None: return lse0 else: return torch.logsumexp( torch.stack([ lse0.view(s_q, h_q), lse1.broadcast_to(s_q, h_q) ], dim=0), dim=0 ) def ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: - o: [s_q, h_q, dv] - o_fp32: [s_q, h_q, dv] - max_logits: [s_q, h_q] - lse: [s_q, h_q] """ indices = t.indices.clone().squeeze(1) if t.topk_length is not None: mask = torch.arange(p.topk, device=t.topk_length.device).unsqueeze(0).broadcast_to(p.s_q, p.topk) >= t.topk_length.unsqueeze(1) # [s_q, topk] indices[mask] = -1 invalid_mask = (indices < 0) | (indices >= p.s_kv) # [s_q, topk] indices[invalid_mask] = 0 q = t.q.float() gathered_kv = t.kv.index_select(dim=0, index=indices.flatten()).reshape(p.s_q, p.topk, p.d_qk).float() # [s_q, topk, d_qk] P = (q @ gathered_kv.transpose(1, 2)) # [s_q, h_q, topk] P *= t.sm_scale P[invalid_mask.unsqueeze(1).broadcast_to(P.shape)] = float("-inf") orig_lse = torch.logsumexp(P, dim=-1) # [s_q, h_q] max_logits = P.max(dim=-1).values # [s_q, h_q] lse_for_o = _merge_two_lse(orig_lse, t.attn_sink, p.s_q, p.h_q) if not torch.is_inference_mode_enabled(): lse_for_o = lse_for_o.clone() lse_for_o[lse_for_o == float("-inf")] = float("+inf") # So that corresponding O will be 0 s_for_o = torch.exp(P - lse_for_o.unsqueeze(-1)) out = s_for_o @ gathered_kv[..., :p.d_v] # [s_q, h_q, dv] lonely_q_mask = orig_lse == float("-inf") # [s_q, h_q] orig_lse[lonely_q_mask] = float("+inf") return (out.to(torch.bfloat16), out, max_logits, orig_lse) def 
ref_sparse_attn_decode( p: TestParam, t: TestcaseForDecode ) -> Tuple[torch.Tensor, torch.Tensor]: """ A reference implementation of sparse decoding attention in PyTorch """ assert p.h_kv == 1 assert p.decode is not None b = p.decode.b def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]: assert kv_scope.indices_in_kvcache is not None topk = kv_scope.indices_in_kvcache.size(-1) indices_in_kv_cache_fixed = torch.clamp_min(kv_scope.indices_in_kvcache, 0) # Otherwise torch.index_select will complain gathered_kv = kv_scope.blocked_k.view(-1, p.d_qk).index_select(0, indices_in_kv_cache_fixed.view(-1)).view(b, p.s_q, topk, p.d_qk) # [b, s_q, topk, d] invalid_mask = kv_scope.indices_in_kvcache == -1 if kv_scope.topk_length is not None: invalid_mask |= torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1) return gathered_kv, invalid_mask gathered_kv, invalid_mask = process_kv_scope(t.kv_scope) if t.extra_kv_scope is not None: gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope) gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2) # [b, s_q, topk+extra_topk, d] invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2) # [b, s_q, topk+extra_topk] gathered_kv = gathered_kv.view(b*p.s_q, -1, p.d_qk).float() gathered_kv[gathered_kv != gathered_kv] = 0.0 q = t.q.float().view(b*p.s_q, p.h_q, p.d_qk) attn_weight = q @ gathered_kv.transpose(-1, -2) # [t.b*t.s_q, t.h_q, topk+extra_topk] attn_weight *= t.sm_scale attn_weight[invalid_mask.view(b*p.s_q, 1, -1).broadcast_to(b*p.s_q, p.h_q, invalid_mask.size(-1))] = float("-inf") lse = attn_weight.logsumexp(dim=-1) # [t.b*t.s_q, t.h_q] attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1)) output = attn_weight @ gathered_kv[..., :p.d_v] # [t.b*t.s_q, t.h_q, t.dv] output = output.view(b, p.s_q, p.h_q, p.d_v) lse = lse.view(b, p.s_q, p.h_q) # Attention sink if t.attn_sink is not None: output *= (1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 
1, p.h_q) - lse))).unsqueeze(-1) # Correct for q tokens which has no attendable k lonely_q_mask = (lse == float("-inf")) output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0 lse[lonely_q_mask] = float("+inf") return output.to(torch.bfloat16), lse.transpose(1, 2) ================================================ FILE: tests/test_flash_mla_dense_decoding.py ================================================ import argparse import math import random import dataclasses from typing import Tuple import torch import kernelkit as kk import flash_mla @dataclasses.dataclass class TestParam: b: int # Batch size s_q: int # Number of queries for one request s_k: int # Seq len, or mean seq len if varlen == True is_varlen: bool is_causal: bool test_performance: bool = True have_zero_seqlen_k: bool = False block_size: int = 64 h_q: int = 128 # Number of q heads h_kv: int = 1 # Number of kv heads d: int = 576 # Q/K head dim (= dv + RoPE dim) dv: int = 512 # V head dim seed: int = 0 def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Generate test data from a given configuration Return: [cache_seqlens, q, block_table, blocked_k] Pay attention: This function changes the random seed """ random.seed(t.seed) torch.manual_seed(t.seed) torch.cuda.manual_seed(t.seed) torch.backends.cudnn.deterministic = True assert t.h_q % t.h_kv == 0 cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu') if t.is_varlen: for i in range(t.b): cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q) if t.have_zero_seqlen_k: zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0 cache_seqlens_cpu[zeros_mask] = 0 max_seqlen = int(cache_seqlens_cpu.max().item()) max_seqlen_pad = kk.cdiv(max_seqlen, 256) * 256 cache_seqlens = cache_seqlens_cpu.cuda() q = torch.randn(t.b, t.s_q, t.h_q, t.d) / 10 q.clamp_(min=-1.0, max=1.0) block_table = torch.arange(t.b * max_seqlen_pad // 
t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size) block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, -1) blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10 blocked_k.clamp_(min=-1.0, max=1.0) for i in range(t.b): cur_len = int(cache_seqlens_cpu[i].item()) cur_num_blocks = kk.cdiv(cur_len, t.block_size) blocked_k[block_table[i][cur_num_blocks:]] = float("nan") if cur_len % t.block_size != 0: blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan") block_table[i][cur_num_blocks:] = 2147480000 return cache_seqlens, q, block_table, blocked_k def reference_torch( cache_seqlens: torch.Tensor, # [batch_size] block_table: torch.Tensor, # [batch_size, ?] q: torch.Tensor, # [batch_size, s_q, h_q, d] blocked_k: torch.Tensor, # [?, block_size, h_kv, d] dv: int, is_causal: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: """ A reference implementation in PyTorch """ def scaled_dot_product_attention( batch_idx: int, query: torch.Tensor, # [h_q, s_q, d] kv: torch.Tensor, # [h_kv, s_k, d] dv: int, is_causal, ) -> Tuple[torch.Tensor, torch.Tensor]: h_q = query.size(0) h_kv = kv.size(0) s_q = query.shape[-2] s_k = kv.shape[-2] query = query.float() kv = kv.float() if h_kv != 1: kv = kv.repeat_interleave(h_q // h_kv, dim=0) kv[kv != kv] = 0.0 attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] if is_causal and query.size(1) > 1: mask = torch.ones(s_q, s_k, dtype=torch.bool) if is_causal: mask = mask.tril(diagonal=s_k - s_q) attn_bias = torch.zeros(s_q, s_k, dtype=torch.float) attn_bias.masked_fill_(mask.logical_not(), float("-inf")) attn_weight += attn_bias.to(q.dtype) attn_weight /= math.sqrt(query.size(-1)) lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] # Correct for q tokens which has no attendable k lonely_q_mask = (lse == float("-inf")) 
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 lse[lonely_q_mask] = float("+inf") return output, lse b, s_q, h_q, d = q.size() block_size = blocked_k.size(1) h_kv = blocked_k.size(2) cache_seqlens_cpu = cache_seqlens.cpu() out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): cur_len = int(cache_seqlens_cpu[i].item()) cur_num_blocks = kk.cdiv(cur_len, block_size) cur_block_indices = block_table[i][0: cur_num_blocks] cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...] cur_out, cur_lse = scaled_dot_product_attention( i, q[i].transpose(0, 1), cur_kv.transpose(0, 1), dv, is_causal ) out_ref[i] = cur_out.transpose(0, 1) lse_ref[i] = cur_lse out_ref = out_ref.to(q.dtype) return out_ref, lse_ref @torch.inference_mode() def test_flash_mla(t: TestParam): print('-------------------------------') print(f"Running on {t}...") # Generating test data torch.cuda.synchronize() cache_seqlens, q, block_table, blocked_k, = generate_test_data(t) tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata() def run_flash_mla(): return flash_mla.flash_mla_with_kvcache( q, blocked_k, block_table, cache_seqlens, t.dv, tile_scheduler_metadata, num_splits, causal=t.is_causal ) out_ans, lse_ans = run_flash_mla() out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal) is_correct = True is_correct &= kk.check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) is_correct &= kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) assert is_correct if t.test_performance: time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla_kernel") mean_attended_seqlens = cache_seqlens.float().mean().item() compute_volume_flop = t.b * t.h_q * t.s_q * sum([ 2 * t.d * mean_attended_seqlens, # Q * K^T 2 * mean_attended_seqlens * t.dv, # attention * V ]) 
q_elem_size = torch.bfloat16.itemsize kv_token_size = t.d * torch.bfloat16.itemsize memory_volume_B = t.b * sum([ t.s_q * t.h_q * (t.d * q_elem_size), # Q mean_attended_seqlens * t.h_kv * kv_token_size, # K/V t.s_q * t.h_q * (t.dv * q_elem_size), # Output ]) achieved_tflops = compute_volume_flop / time_usage / 1e12 achieved_gBps = memory_volume_B / time_usage / 1e9 print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") def main(torch_dtype): device = torch.device("cuda:0") torch.set_default_dtype(torch_dtype) torch.set_default_device(device) torch.cuda.set_device(device) cc_major, cc_minor = torch.cuda.get_device_capability() assert cc_major == 9, "Dense MLA decoding is only supported on sm90 (Hopper) currently." correctness_cases = [ TestParam(b, s_q, s_k, is_varlen, is_causal, test_performance=False, have_zero_seqlen_k=False, block_size=64, h_q=h_q, h_kv=h_kv) for b in [1, 2, 6, 64] for s_q in [1, 2, 4] for s_k in [20, 140, 4096] for h_q in [1, 3, 9, 63, 64, 126, 128] for h_kv in [1, 2, 3, 8] for is_varlen in [False, True] for is_causal in [False, True] if h_q % h_kv == 0 ] corner_cases = [ # Cases where some kv cache have zero length TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, test_performance=False, have_zero_seqlen_k=True, h_q=h_q, h_kv=h_kv) for h_q in [1, 3, 9, 63, 64, 126, 128] for h_kv in [1, 2, 3, 8] for is_causal in [False, True] if h_q % h_kv == 0 ] performance_cases = [ TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, test_performance=True) for is_causal in [False, True] for s_q in [1, 2] for s_k in [4096, 8192, 16384, 32768] ] testcases = correctness_cases + corner_cases + performance_cases for testcase in testcases: test_flash_mla(testcase) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--dtype", type=str, choices=["bf16", "fp16"], default="bf16", help="Data type to use for testing (bf16 or fp16)", ) args = parser.parse_args() torch_dtype = 
torch.bfloat16 if args.dtype == "fp16": torch_dtype = torch.float16 main(torch_dtype) ================================================ FILE: tests/test_flash_mla_sparse_decoding.py ================================================ import time import dataclasses from typing import Tuple, List, Dict, Optional import copy import rich.console import rich.table import torch import kernelkit as kk import flash_mla import lib from lib import TestParam from lib import RawTestParamForDecode as RawTestParam import ref """ Generate testcase for unit test """ def gen_testcase() -> List[RawTestParam]: correctness_cases = [] corner_cases = [] for d_qk in [576, 512]: for have_extra_k in ([False, True] if d_qk == 512 else [False]): for have_extra_topk_len in ([False, True] if have_extra_k else [False]): for have_topk_len in ([False, True] if d_qk == 512 else [False]): for h_q in [64, 128]: cur_correctness_cases = [ RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk, have_topk_length=have_topk_len, enable_attn_sink=True, extra_s_k=extra_s_k, extra_topk=extra_topk, block_size=block_size, extra_block_size=extra_block_size, have_extra_topk_length=have_extra_topk_len, d_qk=d_qk, check_correctness=True, num_runs=0) for (s_k, topk, block_size) in [ (512, 64, 2), (512, 64, 64), (512, 64, 69), (1024, 576, 2), (1024, 576, 61), (2046, 2048, 2), (2046, 2048, 64), (2046, 2048, 576) ] for (extra_s_k, extra_topk, extra_block_size) in ([ (512, 64, 2), (512, 64, 64), (512, 64, 69), (1024, 576, 2), (1024, 576, 61), (2046, 2048, 2), (2046, 2048, 64), (2046, 2048, 576) ] if have_extra_k else [(None, None, None)]) for b in [4, 74, 321] for s_q in [1, 3] for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True]) ] correctness_cases.extend(cur_correctness_cases) cur_corner_cases = [ RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk, is_all_indices_invalid=is_all_indices_invalid, have_zero_seqlen_k=have_zero_seqlen_k, have_topk_length=have_topk_len, 
enable_attn_sink=enable_attn_sink, extra_s_k=extra_s_k, extra_topk=extra_topk, block_size=block_size, extra_block_size=extra_block_size, have_extra_topk_length=have_extra_topk_len, d_qk=d_qk, check_correctness=True, num_runs=0, ) for (s_k, topk, block_size) in [ (512, 64, 61), (650, 576, 53), ] for (extra_s_k, extra_topk, extra_block_size) in ([ (512, 64, 61), (650, 576, 53), ] if have_extra_k else [(None, None, None)]) for b in [4, 74, 321] for s_q in [3] for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True]) for is_all_indices_invalid in [True, False] for have_zero_seqlen_k in [True, False] for enable_attn_sink in [True, False] if (is_all_indices_invalid or have_zero_seqlen_k or enable_attn_sink) ] corner_cases.extend(cur_corner_cases) base_and_bszs = [ # V3.2 (RawTestParam(0, 128, 2, 1, 32768, True, topk=2048, d_qk=576), [2, 64, 74, 128]), # MODEL1 CONFIG1 (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=512, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]), # MODEL1 CONFIG2 (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]), # MODEL1 CONFIG3 (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), # MODEL1 CONFIG4 (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), ] performance_cases = [ # Production cases dataclasses.replace(base, b=b) for base, bszs in base_and_bszs for b in bszs ] + [ # Peak perf cases RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk) for h_q in [64, 128] for d_qk in [512, 576] ] return correctness_cases + corner_cases + performance_cases 
@dataclasses.dataclass class Result: is_correct: bool compute_memory_ratio: float time_usage_per_us: float splitkv_time_usage_us: float combine_time_usage_us: float achieved_tflops: float achieved_gBps: float _counter = kk.Counter() @torch.inference_mode() def test_flash_mla(p: TestParam) -> Result: if p.seed == -1: global _counter p.seed = _counter.next() assert p.decode print("================") print(f"Running on {p}") torch.cuda.empty_cache() t = lib.generate_testcase_for_decode(p) tile_scheduler_metadata, _ = flash_mla.get_mla_metadata() def run_decode(): return lib.run_flash_mla_decode(p, t, tile_scheduler_metadata, None) # We first run the kernel once to generate output data for the correctness test # We must do this first, otherwise when allocating tensors for storing answers, # it may re-use memory that contains the correct answer, leading to false positives if p.check_correctness: torch.cuda.synchronize() out_ans, lse_ans = run_decode() torch.cuda.synchronize() # torch.set_printoptions(profile='full') # print(tile_scheduler_metadata.tile_scheduler_metadata[:, :7]) # We run the performance test before generating the answer for the correctness test to avoid interference performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) if p.num_runs == 0: performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) else: result = kk.bench_kineto(run_decode, p.num_runs) splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel" combine_kernel_name = "flash_fwd_mla_combine_kernel" # Get individual kernel time usages kernel_time_usages_us: Dict[str, Optional[float]] = {} def pick_kernel_time_usage(kernel_name: str): t = [kernel_name in s for s in result.get_kernel_names()] if any(t): assert sum(t) == 1 kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6 else: kernel_time_usages_us[kernel_name] = None pick_kernel_time_usage(splitkv_kernel_name) pick_kernel_time_usage(combine_kernel_name) # Get E2E time usages def 
have_kernel(name: str): return kernel_time_usages_us[name] is not None if kk.is_using_profiling_tools(): e2e_time_usage_us = 1e6 else: assert have_kernel(splitkv_kernel_name) if have_kernel(combine_kernel_name): e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6 else: e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name] assert e2e_time_usage_us is not None flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t) e2e_time_usage_s = e2e_time_usage_us / 1e6 theoritical_compute_memory_ratio = flops_and_mem_vol.flop / flops_and_mem_vol.mem_vol achieved_tflops = flops_and_mem_vol.flop / e2e_time_usage_s / 1e12 achieved_gBps = flops_and_mem_vol.mem_vol / e2e_time_usage_s / 1e9 def print_kernel_time_usage(name: str, short_name: str): if kernel_time_usages_us[name] is not None: print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us') print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}') print(f'Time (per): {e2e_time_usage_us:.1f} us') print_kernel_time_usage(splitkv_kernel_name, "Splitkv") print_kernel_time_usage(combine_kernel_name, "Combine") print(f'TFlops: {achieved_tflops:.1f}') print(f'GB/s: {achieved_gBps:.0f}') performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps) is_correct = True if p.check_correctness: torch.cuda.synchronize() with torch.profiler.record_function("reference_flash_mla"): out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t) is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6) is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) is_correct &= is_out_correct and is_lse_correct performance_result.is_correct = is_correct return performance_result def main(): dtype = torch.bfloat16 device = torch.device("cuda:0") 
torch.set_default_dtype(dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.set_float32_matmul_precision('high') torch.set_num_threads(32) raw_testcases = gen_testcase() testcases = [t.to_test_param() for t in raw_testcases] print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}") is_no_cooldown = lib.is_no_cooldown() num_testcases_len = len(str(len(testcases))) failed_cases = [] results: List[Tuple[TestParam, Result]] = [] for testcase_idx, testcase in enumerate(testcases): if testcase != testcases[0] and testcase.num_runs > 0 and not is_no_cooldown: time.sleep(0.3) # Cooldown print(f"[{testcase_idx+1:{num_testcases_len}d}/{len(testcases)}, {testcase_idx/len(testcases)*100:3.0f}%] ", end='') result = test_flash_mla(testcase) results.append((testcase, result)) if not result.is_correct: failed_cases.append(testcase) import sys sys.exit(1) console = rich.console.Console(width=120) table = rich.table.Table(show_header=True, header_style="bold cyan") table.add_column("topk") table.add_column("Bsz") table.add_column("h_q&k") table.add_column("sq") table.add_column("sk") table.add_column("d_qk") table.add_column("Feats") table.add_column("C/M") table.add_column("TFlops") table.add_column("GBps") table.add_column("us") table.add_column(" ") for testcase, result in results: assert testcase.decode topk_str = f"{testcase.topk}" if testcase.decode.extra_topk is None else f"{testcase.topk}+{testcase.decode.extra_topk}" table.add_row( topk_str, str(testcase.decode.b), f"{testcase.h_q:3d} {testcase.h_kv}", str(testcase.s_q), str(testcase.s_kv), str(testcase.d_qk), " V"[testcase.decode.is_varlen] + " L"[testcase.have_topk_length] + " E"[testcase.decode.have_extra_topk_length], f"{result.compute_memory_ratio:3.0f}", f"{result.achieved_tflops:3.0f}", f"{result.achieved_gBps:4.0f}", f"{result.time_usage_per_us:4.1f}", "" if result.is_correct else "X" ) console.print(table) def geomean(l) -> float: import numpy return 
numpy.exp(numpy.mean(numpy.log(l))) num_correct_testcases = [result.is_correct for t, result in results if t.check_correctness].count(True) num_correctness_cases = sum([1 for t in testcases if t.check_correctness]) if num_correct_testcases == num_correctness_cases: print(f"{kk.colors['GREEN_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}") else: print(f"{kk.colors['RED_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}") for t in failed_cases: print(f"\t{t},") valid_achieved_tflops = [result.achieved_tflops for _, result in results if result.achieved_tflops > 0.1] if len(valid_achieved_tflops) > 0: achieved_tflops_geomean = geomean(valid_achieved_tflops) # > 0.1 to prune out correctness cases print(f"TFlops geomean: {achieved_tflops_geomean:.1f}") if __name__ == "__main__": main() ================================================ FILE: tests/test_flash_mla_sparse_prefill.py ================================================ import time import sys import torch import kernelkit as kk from lib import TestParam import lib import ref _counter = kk.Counter() @torch.inference_mode() def run_test(p: TestParam) -> bool: if p.seed == -1: global _counter p.seed = _counter.next() print("================") print(f"Running on {p}") torch.cuda.empty_cache() t = lib.generate_testcase(p) torch.cuda.synchronize() def run_prefill(): return lib.run_flash_mla_sparse_fwd(p, t, False) prefill_ans_out, prefill_ans_max_logits, prefill_ans_lse = run_prefill() torch.cuda.synchronize() if p.num_runs > 0: flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t) prefill_ans_time = kk.bench_kineto(run_prefill, num_tests=p.num_runs).get_kernel_time("sparse_attn_fwd") prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12 prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12 print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps") if 
p.check_correctness: torch.cuda.synchronize() ref_out, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, t) ref_lse[ref_lse == float("-inf")] = float("+inf") torch.cuda.synchronize() is_correct = True is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6) is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) return is_correct else: return True if __name__ == '__main__': device = torch.device("cuda:0") torch.set_default_dtype(torch.bfloat16) torch.set_default_device(device) torch.cuda.set_device(device) torch.set_float32_matmul_precision('high') correctness_cases = [ # Regular shapes TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk) for d_qk in [512, 576] for h_q in [ 128, 64 ] for s_kv, topk in [ # Regular shapes (128, 128), (256, 256), (512, 512), # Irregular shapes (592, 128), (1840, 256), (1592, 384), (1521, 512), # Irregular shapes with OOB TopK (95, 128), (153, 256), (114, 384), ] for s_q in [ 1, 62, 213 ] ] correctness_cases_with_features = [ TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk) for d_qk in [512, 576] for h_q in [ 128, 64 ] for s_kv, topk in [ (592, 128), (1840, 256), (1592, 384), (1521, 512), (95, 128), (153, 256), (114, 384), ] for s_q in [62, 213] for have_sink_lse in [False, True] for have_attn_sink in [False, True] for have_topk_length in [False, True] ] corner_cases = [ TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk) for d_qk in [512, 576] for h_q in [ 128, 64 ] for s_q, s_kv, topk in [ (1, 128, 128), (1, 256, 256), (1234, 4321, 4096), (4096, 2048, 2048) ] ] + [ # In these cases, some blocks may not have any valid 
topk indices TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk) for d_qk in [512, 576] for h_q in [ 128, 64 ] for s_kv, topk in [ (32, 2048), (64, 8192) ] for s_q in [1, 1024] ] + [ # In this testcase, s_q is really large, so we cannot put it on the second dimension of grid shape TestParam(70000, 256, 256, h_q=h_q, check_correctness=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk) for d_qk in [512, 576] for h_q in [ 128, 64 ] ] performance_case_templates = [ # V3.2 (576, 128, 2048, [8192, 32768, 65536, 98304, 131072]), # MODEL1 CONFIG1 (512, 64, 512, [8192, 32768, 49152, 65536]), # MODEL1 CONFIG2 (512, 128, 1024, [8192, 32768, 49152, 65536]), ] performance_cases = [ TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True) for (d_qk, h_q, topk, s_kv_list) in performance_case_templates for s_q in [4096] for s_kv in s_kv_list ] testcases = correctness_cases + correctness_cases_with_features + corner_cases + performance_cases is_no_cooldown = lib.is_no_cooldown() failed_cases = [] for test in testcases: if test != testcases[0] and test.num_runs > 0 and not is_no_cooldown: time.sleep(0.3) is_correct = run_test(test) if not is_correct: failed_cases.append(test) if len(failed_cases) > 0: print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") for case in failed_cases: print(f" {case}") sys.exit(1) else: print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") ================================================ FILE: tests/test_fmha_sm100.py ================================================ import random import torch from torch.utils.checkpoint import checkpoint import triton from flash_mla import flash_attn_varlen_func from kernelkit import check_is_allclose def get_window_size(causal, window): if window > 0: window_size = (window - 1, 0) if causal else (window - 1, window - 1) else: window_size = (-1, -1) return 
window_size def get_attn_bias(s_q, s_k, causal, window): attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32) if causal: temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) if window > 0: temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q - window) attn_bias.masked_fill_(temp_mask, float("-inf")) temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q + window - 1) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) return attn_bias def sdpa(query, key, value, attn_bias, softmax_scale=None): query = query.float().transpose(-3, -2) key = key.float().transpose(-3, -2) value = value.float().transpose(-3, -2) key = key.repeat_interleave(h // h_k, dim=-3) value = value.repeat_interleave(h // h_k, dim=-3) if softmax_scale is None: softmax_scale = query.shape[-1] ** (-0.5) attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale attn_weight += attn_bias lse = attn_weight.logsumexp(dim=-1) attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) return attn_weight.to(query.dtype) @ value, lse def sdpa_checkpoint(*args, **kwargs): return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, check_correctness: bool = True): print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}, {has_bwd=}, {check_correctness=}") torch.manual_seed(0) random.seed(0) seqlens_q = torch.full((b,), mean_sq, dtype=torch.int32) seqlens_k = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): seqlens_q[i] = max(random.normalvariate(mean_sq, mean_sq / 2), 1) for i in range(b): seqlens_k[i] = max(random.normalvariate(mean_sk, mean_sk / 2), seqlens_q[i].item()) cu_seqlens_q = torch.cumsum(torch.nn.functional.pad(seqlens_q, (1, 0)), 0, dtype=torch.int32) cu_seqlens_k = 
torch.cumsum(torch.nn.functional.pad(seqlens_k, (1, 0)), 0, dtype=torch.int32) total_q = seqlens_q.sum().item() total_k = seqlens_k.sum().item() max_seqlen_q = seqlens_q.max().item() max_seqlen_k = seqlens_k.max().item() total_attn_compute = sum([(get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window) == 0).sum().item() for i in range(b)]) # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") q = torch.randn(total_q, h, d) / 10 k = torch.randn(total_k, h_k, d) / 10 v = torch.randn(total_k, h_k, dv) / 10 grad_out = torch.randn(total_q, h, dv) / 10 softmax_scale = (d + 100) ** (-0.5) q1 = q.clone().requires_grad_() k1 = k.clone().requires_grad_() v1 = v.clone().requires_grad_() if check_correctness: q2 = q.clone().requires_grad_() k2 = k.clone().requires_grad_() v2 = v.clone().requires_grad_() def flash_attn(): q1.grad = k1.grad = v1.grad = None kwargs = {} if causal: kwargs["causal"] = causal if window != 0: kwargs["window_size"] = get_window_size(causal, window) return flash_attn_varlen_func(q1, k1, v1, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, softmax_scale=softmax_scale, is_varlen=varlen, **kwargs) def torch_attn(): q2.grad = k2.grad = v2.grad = None out = [] lse = [] for i in range(b): OUT, LSE = sdpa_checkpoint( q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()], k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()], v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()], attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window), softmax_scale=softmax_scale, ) out.append(OUT.transpose(-3, -2)) lse.append(LSE.transpose(-2, -1)) out = torch.cat(out) lse = torch.cat(lse) return out, lse out_flash, lse_flash = flash_attn() if has_bwd: out_flash.backward(grad_out, retain_graph=True) _dq1 = q1.grad.clone() dk1 = k1.grad.clone() dv1 = v1.grad.clone() if check_correctness: out_torch, lse_torch = torch_attn() assert 
check_is_allclose("out", out_flash.float(), out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) assert check_is_allclose("lse", lse_flash.float(), lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536) if has_bwd: out_torch.backward(grad_out, retain_graph=True) assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) def forward(): return flash_attn() def backward(): q1.grad = k1.grad = v1.grad = None out_flash.backward(grad_out, retain_graph=True) for _ in range(5): out, lse = forward() assert torch.equal(out, out_flash), "out deterministic check failed!" assert torch.equal(lse, lse_flash), "lse deterministic check failed!" if has_bwd: backward() # assert torch.equal(q1.grad, dq1), "dq deterministic check failed!" assert torch.equal(k1.grad, dk1), "dk deterministic check failed!" assert torch.equal(v1.grad, dv1), "dv deterministic check failed!" 
def timer(func, name): t = triton.testing.do_bench(func, warmup=2, rep=3) FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2))) print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOP/s, name: {name}") return t timer(forward, "fwd") if has_bwd: timer(backward, "bwd") if __name__ == "__main__": dtype = torch.bfloat16 torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) torch.cuda.set_device(device) torch.set_float32_matmul_precision("high") b = 2 window = 0 has_bwd = False for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]: for varlen in [False, True]: for (h, h_k) in [(128, 128), (32, 4)]: if h != h_k: has_bwd = False else: has_bwd = True for (d, dv) in [(128, 128), (192, 128)]: for causal in [False, True]: skip_correctness_check = mean_sq == 8192 and mean_sk == 8192 and h == 128 test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, not skip_correctness_check)