[
  {
    "path": ".gitignore",
    "content": "build\n*.so\n*.egg-info/\n__pycache__/\ndist/\n*perf.csv\n*.png\n/.vscode\ncompile_commands.json\n.cache\n/dev\n/.clangd\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"csrc/cutlass\"]\n\tpath = csrc/cutlass\n\turl = https://github.com/NVIDIA/cutlass.git\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 DeepSeek\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# FlashMLA\n\n## Introduction\n\nFlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations:\n\n**Sparse Attention Kernels**\n\n*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).*\n\n- Token-level sparse attention for the prefill stage\n- Token-level sparse attention for the decoding stage, with FP8 KV cache\n\n**Dense Attention Kernels**\n\n- Dense attention for the prefill stage\n- Dense attention for the decoding stage\n\n## News\n\n- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFLOPS during prefilling and 410 TFLOPS during decoding. We have also published a deep-dive blog post on our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md).\n- **2025.08.01 Kernels for MHA on SM100**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on SM100!\n- **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md).\n- **2025.04.22 Performance Update**: We're excited to announce the new release of FlashMLA, which delivers a 5% to 15% performance improvement for compute-bound workloads, achieving up to 660 TFLOPS on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 
🚀🚀🚀\n\n## Performance\n\n#### Test & benchmark MLA decoding (Sparse & Dense):\n\n```bash\npython tests/test_flash_mla_dense_decoding.py\npython tests/test_flash_mla_sparse_decoding.py\n```\n\nThe dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configurations and 660 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8, and up to 350 TFLOPS on B200 (where it has not been fully optimized yet).\n\n#### Test & benchmark MHA prefill (Dense):\n\n```bash\npython tests/test_fmha_sm100.py\n```\n\nIt achieves up to 1460 TFLOPS in forward and 1000 TFLOPS in backward computation on B200, as reported by NVIDIA.\n\n#### Test & benchmark MLA prefill (Sparse):\n\n```bash\npython tests/test_flash_mla_sparse_prefill.py\n```\n\nIt achieves up to 640 TFLOPS in forward computation on H800 SXM5 with CUDA 12.8, and up to 1450 TFLOPS on B200 with CUDA 12.9.\n\n## Requirements\n\n- SM90 / SM100 (See the support matrix below)\n- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels)\n- PyTorch 2.0 and above\n\nSupport matrix:\n\n| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |\n| :---: | :---: | :---: | :---: |\n| Dense Decoding | SM90 | MQA | BF16 |\n| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] |\n| Dense Prefill | SM100 | MHA |  |\n| Sparse Prefill | SM90 & SM100 | MQA |  |\n\n[1]: For more details on using the FP8 KV cache, see the \"FP8 KV Cache\" notes in the Usage section below.\n\n[2]: Here \"MLA Mode\" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). 
For a detailed explanation of these modes, please refer to the appendix of the [DeepSeek-V3.2 paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).\n\n## Installation\n\n```bash\ngit clone https://github.com/deepseek-ai/FlashMLA.git flash-mla\ncd flash-mla\ngit submodule update --init --recursive\npip install -v .\n```\n\n## Usage\n\n### MLA Decoding\n\nTo use the MLA decoding kernels, call `get_mla_metadata` once before the decoding loop to get the tile scheduler metadata. Then, call `flash_mla_with_kvcache` in each decoding step. For example:\n\n```python\nfrom flash_mla import get_mla_metadata, flash_mla_with_kvcache\n\ntile_scheduler_metadata, num_splits = get_mla_metadata(\n    cache_seqlens,\n    s_q * h_q // h_kv,\n    h_kv,\n    h_q,\n    is_fp8,\n    topk,\n)\n\nfor i in range(num_layers):\n    ...\n    o_i, lse_i = flash_mla_with_kvcache(\n        q_i, kvcache_i, block_table, cache_seqlens, dv,\n        tile_scheduler_metadata, num_splits,\n        is_causal, is_fp8_kvcache, indices,\n    )\n    ...\n```\n\nWhere:\n\n- `s_q` is the number of query tokens per sequence. If MTP (speculative decoding) is disabled, it should be 1.\n- `h_kv` is the number of key-value heads.\n- `h_q` is the number of query heads.\n\n**FP8 KV Cache:**\nIf `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the \"FP8 with scale\" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.\n\nIn the \"FP8 with scale\" format, each token's KV cache occupies 656 bytes, structured as:\n-   **First 512 bytes:** The \"quantized NoPE\" part, containing 512 `float8_e4m3` values.\n-   **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.\n-   **Last 128 bytes:** The \"RoPE\" part, containing 64 `bfloat16` values. 
This part is left unquantized to preserve accuracy.\n\nSee `tests/quant.py` for quantization and dequantization details.\n\n**Sparse Attention (`indices` tensor):**\nThe `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens.\n\n-   **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`.\n-   **Format:** `indices[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th selected token for the j-th query token of the i-th sequence in the batch. Since the index of the page block has already been encoded into `indices`, the kernel does not require the `block_table` parameter.\n-   **Invalid entries:** Set invalid indices to `-1`.\n\n**Return Values:**\nThe kernel returns `(out, lse)`, where:\n-   `out` is the attention result.\n-   `lse` is the log-sum-exp value of the attention scores for each query head.\n\nSee `tests/test_flash_mla_dense_decoding.py` and `tests/test_flash_mla_sparse_decoding.py` for complete examples.\n\n### Sparse MLA Prefill\n\nFor the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters:\n-   `q`: Query tensor of shape `[s_q, h_q, d_qk]`\n-   `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]`\n-   `indices`: Indices tensor of shape `[s_q, h_kv, topk]`\n-   `sm_scale`: The softmax scaling factor (a scalar) applied to the attention logits\n\n**Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the `indices` parameter to simulate batch processing.\n\n**Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`.\n\n**Return Values and Equivalent PyTorch Code:**\nThe kernel returns `(out, max_logits, lse)`. 
This is equivalent to the following PyTorch operations:\n\n```python\n# Q: [s_q, h_q, d_qk], bfloat16\n# kv: [s_kv, h_kv, d_qk], bfloat16\n# indices: [s_q, h_kv, topk], int32\n\nkv = kv.squeeze(1)  # [s_kv, d_qk], h_kv must be 1\nindices = indices.squeeze(1)    # [s_q, topk]\nfocused_kv = kv[indices]    # For the i-th query token, the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk].\n\nP = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e)    # [s_q, h_q, topk]\nmax_logits = P.max(dim=-1).values # [s_q, h_q]\nlse = log2sumexp2(P, dim=-1, base=2)   # [s_q, h_q]; \"log2sumexp2\" means that the exponentiation and logarithm are base-2\nS = torch.exp2(P - lse.unsqueeze(-1))      # [s_q, h_q, topk]\nout = S @ focused_kv  # [s_q, h_q, d_qk]\n\nreturn (out, max_logits, lse)\n```\n\nSee `tests/test_flash_mla_sparse_prefill.py` for a complete example.\n\n### Dense MHA Prefill\n\nThis kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using:\n-   `flash_attn_varlen_func`\n-   `flash_attn_varlen_qkvpacked_func`\n-   `flash_attn_varlen_kvpacked_func`\n\nThe usage is similar to the `flash_attn` package. 
See `tests/test_fmha_sm100.py` for a complete example.\n\n## Acknowledgement\n\nFlashMLA is inspired by the [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.\n\n## Community Support\n\n### MetaX\nFor MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com).\n\nThe corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA)\n\n\n### Moore Threads\nFor the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/).\n\nThe corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA).\n\n\n### Hygon DCU\nFor the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/).\n\nThe corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention).\n\n\n### Intellifusion\nFor the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com).\n\nThe corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py).\n\n\n### Iluvatar Corex\nFor Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com).\n\nThe corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla)\n\n\n### AMD Instinct\nFor AMD Instinct GPUs, visit the official website: [AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html).\n\nThe corresponding FlashMLA version can be found at: [AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py)\n\n## Citation\n\n```bibtex\n@misc{flashmla2025,\n      title={FlashMLA: Efficient Multi-head Latent Attention Kernels},\n      author={Jiashi Li and Shengyu Liu},\n      year={2025},\n     
 publisher = {GitHub},\n      howpublished = {\\url{https://github.com/deepseek-ai/FlashMLA}},\n}\n```\n"
  },
  {
    "path": "benchmark/bench_flash_mla.py",
    "content": "# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a\nimport argparse\nimport math\nimport random\n\nimport flashinfer\nimport torch\nimport triton\nimport triton.language as tl\n\n# pip install flashinfer-python\nfrom flash_mla import flash_mla_with_kvcache, get_mla_metadata\n\n\ndef scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):\n    query = query.float()\n    key = key.float()\n    value = value.float()\n    key = key.repeat_interleave(h_q // h_kv, dim=0)\n    value = value.repeat_interleave(h_q // h_kv, dim=0)\n    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))\n    if is_causal:\n        s_q = query.shape[-2]\n        s_k = key.shape[-2]\n        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)\n        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)\n        attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n        attn_bias.to(query.dtype)\n        attn_weight += attn_bias\n    lse = attn_weight.logsumexp(dim=-1)\n    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)\n    return attn_weight @ value, lse\n\n\n@torch.inference_mode()\ndef run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):\n    for i in range(b):\n        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float(\"nan\")\n    blocked_v = blocked_k[..., :dv]\n\n    def ref_mla():\n        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)\n        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)\n        for i in range(b):\n            begin = i * max_seqlen_pad\n            end = begin + cache_seqlens[i]\n            O, LSE = scaled_dot_product_attention(\n                q[i].transpose(0, 1),\n                blocked_k.view(-1, 
h_kv, d)[begin:end].transpose(0, 1),\n                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),\n                h_q, h_kv,\n                is_causal=causal,\n            )\n            out[i] = O.transpose(0, 1)\n            lse[i] = LSE\n        return out, lse\n\n    out_torch, lse_torch = ref_mla()\n    t = triton.testing.do_bench(ref_mla)\n    return out_torch, lse_torch, t\n\n@torch.inference_mode()\ndef run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):\n    for i in range(b):\n        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float(\"nan\")\n    blocked_v = blocked_k[..., :dv]\n\n    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)\n\n    def flash_mla():\n        return flash_mla_with_kvcache(\n            q, blocked_k, block_table, cache_seqlens, dv,\n            tile_scheduler_metadata, num_splits, causal=causal,\n        )\n\n    out_flash, lse_flash = flash_mla()\n    t = triton.testing.do_bench(flash_mla)\n    return out_flash, lse_flash, t\n\n\n@torch.inference_mode()\ndef run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):\n    \n    for i in range(b):\n        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float(\"nan\")\n\n    assert d > dv, \"mla with rope dim should be larger than no rope dim\"\n    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()\n    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()\n    \n    \n    kv_indptr = [0]\n    kv_indices = []\n    for i in range(b):\n        seq_len = cache_seqlens[i]\n        assert seq_len > 0\n        num_blocks = (seq_len + block_size - 1) // block_size\n        kv_indices.extend(block_table[i, :num_blocks])\n        kv_indptr.append(kv_indptr[-1] + num_blocks)\n    
for seq_len in cache_seqlens[1:]:\n        kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])\n        \n    q_indptr = torch.arange(0, b + 1).int() * s_q\n    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)\n    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)\n\n    mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(\n        torch.empty(128 * 1024 * 1024, dtype=torch.int8),\n        backend=\"fa3\"\n    )\n    mla_wrapper.plan(\n        q_indptr,\n        kv_indptr,\n        kv_indices,\n        cache_seqlens,\n        h_q,\n        dv,\n        d-dv,\n        block_size,\n        causal,\n        1 / math.sqrt(d),\n        q.dtype,\n        blocked_k.dtype,\n    )\n\n    def flash_infer():\n        output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True)\n        return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)\n\n    out_flash, lse_flash = flash_infer()\n    t = triton.testing.do_bench(flash_infer)\n    return out_flash, lse_flash, t\n\n\n@triton.jit\ndef _mla_attn_kernel(\n    Q_nope,\n    Q_pe,\n    Kv_c_cache,\n    K_pe_cache,\n    Req_to_tokens,\n    B_seq_len,\n    O,\n    sm_scale,\n    stride_q_nope_bs,\n    stride_q_nope_h,\n    stride_q_pe_bs,\n    stride_q_pe_h,\n    stride_kv_c_bs,\n    stride_k_pe_bs,\n    stride_req_to_tokens_bs,\n    stride_o_b,\n    stride_o_h,\n    stride_o_s,\n    BLOCK_H: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    NUM_KV_SPLITS: tl.constexpr,\n    PAGE_SIZE: tl.constexpr,\n    HEAD_DIM_CKV: tl.constexpr,\n    HEAD_DIM_KPE: tl.constexpr,\n):\n    cur_batch = tl.program_id(1)\n    cur_head_id = tl.program_id(0)\n    split_kv_id = tl.program_id(2)\n\n    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)\n\n    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)\n    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)\n    offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + 
offs_d_ckv[None, :]\n    q_nope = tl.load(Q_nope + offs_q_nope)\n\n    offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)\n    offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :]\n    q_pe = tl.load(Q_pe + offs_q_pe)\n\n    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float(\"inf\")\n    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32)\n\n    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)\n    split_kv_start = kv_len_per_split * split_kv_id\n    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)\n\n    for start_n in range(split_kv_start, split_kv_end, BLOCK_N):\n        offs_n = start_n + tl.arange(0, BLOCK_N)\n        kv_page_number = tl.load(\n            Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE,\n            mask=offs_n < split_kv_end,\n            other=0,\n        )\n        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE\n        offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None]\n        k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0)\n\n        qk = tl.dot(q_nope, k_c.to(q_nope.dtype))\n\n        offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None]\n        k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0)\n\n        qk += tl.dot(q_pe, k_pe.to(q_pe.dtype))\n        qk *= sm_scale\n\n        qk = tl.where(offs_n[None, :] < split_kv_end, qk, float(\"-inf\"))\n\n        v_c = tl.trans(k_c)\n\n        n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n        re_scale = tl.exp(e_max - n_e_max)\n        p = tl.exp(qk - n_e_max[:, None])\n        acc *= re_scale[:, None]\n        acc += tl.dot(p.to(v_c.dtype), v_c)\n\n        e_sum = e_sum * re_scale + tl.sum(p, 1)\n        e_max = n_e_max\n    offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + 
offs_d_ckv[None, :]\n    tl.store(O + offs_o, acc / e_sum[:, None])\n    offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV\n    tl.store(O + offs_o_1, e_max + tl.log(e_sum))\n\n\ndef _mla_attn(\n    q_nope,\n    q_pe,\n    kv_c_cache,\n    k_pe_cache,\n    attn_logits,\n    req_to_tokens,\n    b_seq_len,\n    num_kv_splits,\n    sm_scale,\n    page_size,\n):\n    batch_size, head_num = q_nope.shape[0], q_nope.shape[1]\n    head_dim_ckv = q_nope.shape[-1]\n    head_dim_kpe = q_pe.shape[-1]\n\n    BLOCK_H = 16\n    BLOCK_N = 64\n    grid = (\n        triton.cdiv(head_num, BLOCK_H),\n        batch_size,\n        num_kv_splits,\n    )\n    _mla_attn_kernel[grid](\n        q_nope,\n        q_pe,\n        kv_c_cache,\n        k_pe_cache,\n        req_to_tokens,\n        b_seq_len,\n        attn_logits,\n        sm_scale,\n        # stride\n        q_nope.stride(0),\n        q_nope.stride(1),\n        q_pe.stride(0),\n        q_pe.stride(1),\n        kv_c_cache.stride(-2),\n        k_pe_cache.stride(-2),\n        req_to_tokens.stride(0),\n        attn_logits.stride(0),\n        attn_logits.stride(1),\n        attn_logits.stride(2),\n        BLOCK_H=BLOCK_H,\n        BLOCK_N=BLOCK_N, \n        NUM_KV_SPLITS=num_kv_splits,\n        PAGE_SIZE=page_size,\n        HEAD_DIM_CKV=head_dim_ckv,\n        HEAD_DIM_KPE=head_dim_kpe,\n    )\n\n@triton.jit\ndef _mla_softmax_reducev_kernel(\n    Logits,\n    B_seq_len,\n    O,\n    stride_l_b,\n    stride_l_h,\n    stride_l_s,\n    stride_o_b,\n    stride_o_h,\n    NUM_KV_SPLITS: tl.constexpr,\n    HEAD_DIM_CKV: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)\n\n    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)\n\n    e_sum = 0.0\n    e_max = -float(\"inf\")\n    acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32)\n\n    offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv\n    offs_l_1 = 
cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV\n\n    for split_kv_id in range(0, NUM_KV_SPLITS):\n        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)\n        split_kv_start = kv_len_per_split * split_kv_id\n        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)\n\n        if split_kv_end > split_kv_start:\n            logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s)\n            logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s)\n\n            n_e_max = tl.maximum(logits_1, e_max)\n            old_scale = tl.exp(e_max - n_e_max)\n            acc *= old_scale\n            exp_logic = tl.exp(logits_1 - n_e_max)\n            acc += exp_logic * logits\n\n            e_sum = e_sum * old_scale + exp_logic\n            e_max = n_e_max\n    \n    tl.store(\n        O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,\n        acc / e_sum,\n    )\n\n\ndef _mla_softmax_reducev(\n    logits,\n    o,\n    b_seq_len,\n    num_kv_splits,\n):\n    batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2]\n    grid = (batch_size, head_num)\n    _mla_softmax_reducev_kernel[grid](\n        logits,\n        b_seq_len,\n        o,\n        logits.stride(0),\n        logits.stride(1),\n        logits.stride(2),\n        o.stride(0),\n        o.stride(1),\n        NUM_KV_SPLITS=num_kv_splits,\n        HEAD_DIM_CKV=head_dim_ckv,\n        num_warps=4,\n        num_stages=2,\n    )\n\ndef mla_decode_triton(\n    q_nope,\n    q_pe,\n    kv_c_cache,\n    k_pe_cache,\n    o,\n    req_to_tokens,\n    b_seq_len,\n    attn_logits,\n    num_kv_splits,\n    sm_scale,\n    page_size,\n):\n    assert num_kv_splits == attn_logits.shape[2]\n    _mla_attn(\n        q_nope,\n        q_pe,\n        kv_c_cache,\n        k_pe_cache,\n        attn_logits,\n        req_to_tokens,\n        b_seq_len,\n        num_kv_splits,\n        sm_scale,\n        page_size,\n    )\n    
_mla_softmax_reducev(\n        attn_logits,\n        o,\n        b_seq_len,\n        num_kv_splits,\n    )\n    \n\n@torch.inference_mode()\ndef run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):\n    \n    for i in range(b):\n        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float(\"nan\")\n    blocked_v = blocked_k[..., :dv]\n    \n    assert d > dv, \"mla with rope dim should be larger than no rope dim\"\n    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()\n    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()\n\n    def flash_mla_triton():\n        num_kv_splits = 32\n        o = torch.empty([b * s_q, h_q, dv])\n        attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])\n        mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size)\n        return o.view([b, s_q, h_q, dv])\n\n    out_flash = flash_mla_triton()\n    t = triton.testing.do_bench(flash_mla_triton)\n    return out_flash, None, t\n\n\nFUNC_TABLE = {\n    \"torch\": run_torch_mla,\n    \"flash_mla\": run_flash_mla,\n    \"flash_infer\": run_flash_infer,\n    \"flash_mla_triton\": run_flash_mla_triton,\n}\n    \ndef compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):\n    print(f\"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}\")\n    device = torch.device(\"cuda:0\")\n    torch.set_default_dtype(dtype)\n    torch.set_default_device(device)\n    torch.cuda.set_device(device)\n    torch.manual_seed(0)\n    random.seed(0)\n    assert baseline in FUNC_TABLE\n    assert target in FUNC_TABLE\n    baseline_func = 
FUNC_TABLE[baseline]\n    target_func = FUNC_TABLE[target]\n    \n    total_seqlens = cache_seqlens.sum().item()\n    mean_seqlens = cache_seqlens.float().mean().int().item()\n    max_seqlen = cache_seqlens.max().item()\n    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256\n    # print(f\"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}\")\n\n    q = torch.randn(b, s_q, h_q, d)\n    block_size = 64\n    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)\n    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)\n    \n    out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)\n    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)\n    \n    torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), \"out\"\n    if target not in [\"flash_infer\", \"flash_mla_triton\"] and baseline not in [\"flash_infer\", \"flash_mla_triton\"]:\n        # flash_infer has a different lse return value\n        # flash_mla_triton doesn't return lse\n        torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), \"lse\"\n\n    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2\n    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)\n    print(f\"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s\")\n    print(f\"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s\")\n    return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b\n\n\ndef compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):\n    print(f\"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, 
{d=}, {dv=}, {causal=}, {dtype=}\")\n    torch.set_default_dtype(dtype)\n    device = torch.device(\"cuda:0\")\n    torch.set_default_device(device)\n    torch.cuda.set_device(device)\n    torch.manual_seed(0)\n    random.seed(0)\n    assert target in FUNC_TABLE\n    target_func = FUNC_TABLE[target]\n    \n    total_seqlens = cache_seqlens.sum().item()\n    mean_seqlens = cache_seqlens.float().mean().int().item()\n    max_seqlen = cache_seqlens.max().item()\n    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256\n    # print(f\"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}\")\n\n    q = torch.randn(b, s_q, h_q, d)\n    block_size = 64\n    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)\n    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)\n    \n    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)\n\n    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2\n    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)\n    print(f\"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s\")\n    return bytes / 10 ** 6 / perf_b\n\n\navailable_targets = [\n    \"torch\",\n    \"flash_mla\",\n    \"flash_infer\",\n    \"flash_mla_triton\",\n]\n\nshape_configs = [\n    {\"b\": batch, \"s_q\": 1, \"cache_seqlens\": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device=\"cuda\"), \"h_q\": head, \"h_kv\": 1, \"d\": 512+64, \"dv\": 512, \"causal\": True, \"dtype\": torch.bfloat16}\n    for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128]\n]\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--baseline\", type=str, default=\"torch\")\n    parser.add_argument(\"--target\", type=str, 
default=\"flash_mla\")\n    parser.add_argument(\"--all\", action=\"store_true\")\n    parser.add_argument(\"--one\", action=\"store_true\")\n    parser.add_argument(\"--compare\", action=\"store_true\")\n    args = parser.parse_args()\n    return args\n\n    \nif __name__ == \"__main__\":\n    args = get_args()\n    benchmark_type = \"all\" if args.all else f\"{args.baseline}_vs_{args.target}\" if args.compare else args.target\n    with open(f\"{benchmark_type}_perf.csv\", \"w\") as fout:\n        fout.write(\"name,batch,seqlen,head,bw\\n\")\n        for shape in shape_configs:\n            if args.all:\n                for target in available_targets:\n                    perf = compare_a(target, shape[\"b\"], shape[\"s_q\"], shape[\"cache_seqlens\"], shape[\"h_q\"], shape[\"h_kv\"], shape[\"d\"], shape[\"dv\"], shape[\"causal\"], shape[\"dtype\"])\n                    fout.write(f'{target},{shape[\"b\"]},{shape[\"cache_seqlens\"].float().mean().cpu().item():.0f},{shape[\"h_q\"]},{perf:.0f}\\n')\n            elif args.compare:\n                perfa, perfb = compare_ab(args.baseline, args.target, shape[\"b\"], shape[\"s_q\"], shape[\"cache_seqlens\"], shape[\"h_q\"], shape[\"h_kv\"], shape[\"d\"], shape[\"dv\"], shape[\"causal\"], shape[\"dtype\"])\n                fout.write(f'{args.baseline},{shape[\"b\"]},{shape[\"cache_seqlens\"].float().mean().cpu().item():.0f},{shape[\"h_q\"]},{perfa:.0f}\\n')\n                fout.write(f'{args.target},{shape[\"b\"]},{shape[\"cache_seqlens\"].float().mean().cpu().item():.0f},{shape[\"h_q\"]},{perfb:.0f}\\n')\n            elif args.one:\n                perf = compare_a(args.target, shape[\"b\"], shape[\"s_q\"], shape[\"cache_seqlens\"], shape[\"h_q\"], shape[\"h_kv\"], shape[\"d\"], shape[\"dv\"], shape[\"causal\"], shape[\"dtype\"])\n                fout.write(f'{args.target},{shape[\"b\"]},{shape[\"cache_seqlens\"].float().mean().cpu().item():.0f},{shape[\"h_q\"]},{perf:.0f}\\n')"
  },
  {
    "path": "benchmark/visualize.py",
    "content": "import argparse\n\nimport matplotlib.pyplot as plt\nimport pandas as pd\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Visualize benchmark results')\n    parser.add_argument('--file', type=str, default='all_perf.csv',\n                        help='Path to the CSV file with benchmark results (default: all_perf.csv)')\n    return parser.parse_args()\n\nargs = parse_args()\nfile_path = args.file\n\ndf = pd.read_csv(file_path)\n\nnames = df['name'].unique()\n\nfor name in names:\n    subset = df[df['name'] == name]\n    plt.plot(subset['seqlen'], subset['bw'], label=name)\n\nplt.title('bandwidth')\nplt.xlabel('seqlen')\nplt.ylabel('bw (GB/s)')\nplt.legend()\n\nplt.savefig(f'{file_path.split(\".\")[0].split(\"/\")[-1]}_bandwidth_vs_seqlen.png')"
  },
  {
    "path": "csrc/api/api.cpp",
    "content": "#include <pybind11/pybind11.h>\n\n#include \"sparse_fwd.h\"\n#include \"sparse_decode.h\"\n#include \"dense_decode.h\"\n#include \"dense_fwd.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.doc() = \"FlashMLA\";\n    m.def(\"sparse_decode_fwd\", &sparse_attn_decode_interface);\n    m.def(\"dense_decode_fwd\", &dense_attn_decode_interface);\n    m.def(\"sparse_prefill_fwd\", &sparse_attn_prefill_interface);\n    m.def(\"dense_prefill_fwd\", &FMHACutlassSM100FwdRun);\n    m.def(\"dense_prefill_bwd\", &FMHACutlassSM100BwdRun);\n}\n"
  },
  {
    "path": "csrc/api/common.h",
    "content": "#pragma once\n\n#include <span>\n\n#include <torch/extension.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <kerutils/supplemental/torch_tensors.h>\n\n#include <cutlass/bfloat16.h>\n\nstatic constexpr float LOG_2_E = 1.44269504f;\n\n// Instantiation for tensor.data_ptr<cutlass::bfloat16_t>()\ntemplate<>\ninline cutlass::bfloat16_t* at::TensorBase::data_ptr<cutlass::bfloat16_t>() const {\n    return reinterpret_cast<cutlass::bfloat16_t*>(this->data_ptr());\n}\n\n// A struct that holds the architecture information of the current GPU.\nstruct Arch {\n    int major;\n    int minor;\n    int num_sms;\n    cudaDeviceProp* device_prop;\n\n    Arch() {\n        device_prop = at::cuda::getCurrentDeviceProperties();\n        major = device_prop->major;\n        minor = device_prop->minor;\n        num_sms = device_prop->multiProcessorCount;\n    }\n\n    bool is_sm90a() const {\n        return major == 9 && minor == 0;\n    }\n\n    bool is_sm100f() const {\n        return major == 10;\n    }\n};\n\n// Convert int64_t stride to int32_t, with overflow check.\ninline int int64_stride_to_int(int64_t orig_stride) {\n    if (orig_stride > std::numeric_limits<int>::max()) {\n        TORCH_CHECK(false, \"[FlashMLA] Stride exceeds int32 limit: \", orig_stride);\n    }\n    return static_cast<int>(orig_stride);\n}\n\n#define DISPATCH_NUM_HEADS(NUM_HEADS, CONSTEXPR_NAME, ...) \\\n    [&] () { \\\n        if (NUM_HEADS == 128) { \\\n            static constexpr int CONSTEXPR_NAME = 128; \\\n            return __VA_ARGS__(); \\\n        } else if (NUM_HEADS == 64) { \\\n            static constexpr int CONSTEXPR_NAME = 64; \\\n            return __VA_ARGS__(); \\\n        } else { \\\n            TORCH_CHECK(false, \"Unsupported num_heads_q: \", NUM_HEADS); \\\n        } \\\n    } ();\n\n#define DISPATCH_HEAD_DIM(HEAD_DIM, CONSTEXPR_NAME, ...) 
\\\n[&] () { \\\n    if (HEAD_DIM == 576) { \\\n        static constexpr int CONSTEXPR_NAME = 576; \\\n        return __VA_ARGS__(); \\\n    } else if (HEAD_DIM == 512) { \\\n        static constexpr int CONSTEXPR_NAME = 512; \\\n        return __VA_ARGS__(); \\\n    } else { \\\n        TORCH_CHECK(false, \"Unsupported head_dim_qk: \", HEAD_DIM); \\\n    } \\\n} ();\n\n#define DISPATCH_BOOLEAN_FLAG(FLAG, CONSTEXPR_NAME, ...) \\\n    [&] () { \\\n        if (FLAG) { \\\n            static constexpr bool CONSTEXPR_NAME = true; \\\n            return __VA_ARGS__(); \\\n        } else { \\\n            static constexpr bool CONSTEXPR_NAME = false; \\\n            return __VA_ARGS__(); \\\n        } \\\n    } ();\n\n#define DISPATCH_MODEL_TYPE(MODEL_TYPE, CONSTEXPR_NAME, ...) \\\n[&] () { \\\n    if (MODEL_TYPE == ModelType::V32) { \\\n        static constexpr ModelType CONSTEXPR_NAME = ModelType::V32; \\\n        return __VA_ARGS__(); \\\n    } else if (MODEL_TYPE == ModelType::MODEL1) { \\\n        static constexpr ModelType CONSTEXPR_NAME = ModelType::MODEL1; \\\n        return __VA_ARGS__(); \\\n    } else { \\\n        TORCH_CHECK(false, \"Unsupported model type: \", (int)MODEL_TYPE); \\\n    } \\\n} ();\n\n// The following code is adapted from https://ykiko.me/en/articles/680412313/, which converts enum values to string names.\ntemplate<auto value>\nconstexpr auto get_static_enum_name(){\n    std::string_view name;\n#if __GNUC__ || __clang__\n    name = __PRETTY_FUNCTION__;\n    std::size_t start = name.find('=') + 2;\n    std::size_t end = name.size() - 1;\n    name = std::string_view{ name.data() + start, end - start };\n    start = name.find(\"::\");\n#elif _MSC_VER\n    name = __FUNCSIG__;\n    std::size_t start = name.find('<') + 1;\n    std::size_t end = name.rfind(\">(\");\n    name = std::string_view{ name.data() + start, end - start };\n    start = name.rfind(\"::\");\n#endif\n    return start == std::string_view::npos ? 
name : std::string_view {\n            name.data() + start + 2, name.size() - start - 2\n    };\n}\n\ntemplate<typename T, std::size_t N = 0> \nstatic constexpr std::size_t get_enum_max(){\n    constexpr T value = static_cast<T>(N);\n    if constexpr (get_static_enum_name<value>().find(\")\") == std::string_view::npos)\n        return get_enum_max<T, N + 1>();\n    else\n        return N;\n}\n\ntemplate<typename T> requires std::is_enum_v<T>\nstatic constexpr std::string get_dynamic_enum_name(T value){\n    constexpr std::size_t num = get_enum_max<T>();\n    constexpr auto names = []<std::size_t... Is>(std::index_sequence<Is...>){\n        return std::array<std::string_view, num>{ \n            get_static_enum_name<static_cast<T>(Is)>()... \n        };\n    }(std::make_index_sequence<num>{});\n    return (std::string)names[static_cast<std::size_t>(value)];\n}\n\n// A shortcut macro to declare supported features in an implementation class.\n#define DECLARE_SUPPORTED_FEATURES(...) \\\nprotected: \\\n    static constexpr FeatureT features[] = { __VA_ARGS__ }; \\\n    constexpr inline std::span<const FeatureT> get_supported_features() const override { \\\n        return features; \\\n    }\n\n/*\nImplBase - The base class for every implementation.\n\nEvery implementation should inherit from this class and implement the pure virtual functions, including:\n- `run_`: The function that runs the implementation.\n- `get_supported_features`: The function that returns the supported features of the implementation. 
You may use `DECLARE_SUPPORTED_FEATURES` to declare the supported features in a concise way.\n\nThe dispatcher will invoke `ImplBase::run()`, which checks if all required features are supported by the implementation, and then calls `run_`.\n*/\ntemplate<\n    typename RunArgT_,\n    typename FeatureT_\n>\nclass ImplBase {\nprotected:\n    using RunArgT = RunArgT_;\n    using FeatureT = FeatureT_;\n\n    virtual inline void run_(const RunArgT &params, const std::vector<FeatureT> &required_features) = 0;\n\n    constexpr virtual inline std::span<const FeatureT> get_supported_features() const = 0;\n\n    virtual ~ImplBase() = default;\n\npublic:\n    inline bool check_if_all_features_are_supported(const std::vector<FeatureT> &required_features) {\n        for (const auto &required_feature : required_features) {\n            bool is_supported = false;\n            for (const auto &supported_feature : get_supported_features()) {\n                if (required_feature == supported_feature) {\n                    is_supported = true;\n                    break;\n                }\n            }\n            if (!is_supported) {\n                return false;\n            }\n        }\n        return true;\n    }\n\n    inline void check_if_all_features_are_supported_and_abort(const std::vector<FeatureT> &required_features) {\n        if (!check_if_all_features_are_supported(required_features)) {\n            fprintf(stderr, \"[FlashMLA] Error: The chosen implementation does not support all required features.\\n\");\n            fprintf(stderr, \"Required features:\\n\");\n            for (const auto &f : required_features) {\n                fprintf(stderr, \"  - %3d: %s\\n\", static_cast<int>(f), get_dynamic_enum_name(f).c_str());\n            }\n            fprintf(stderr, \"\\n\");\n            fprintf(stderr, \"Supported features:\\n\");\n            for (const auto &supported_feature : get_supported_features()) {\n                fprintf(stderr, \"  - %3d: %s\\n\", 
static_cast<int>(supported_feature), get_dynamic_enum_name(supported_feature).c_str());\n            }\n            fprintf(stderr, \"\\n\");\n            fprintf(stderr, \"Features that are required but not supported:\\n\");\n            for (const auto &required_feature : required_features) {\n                bool is_supported = false;\n                for (const auto &supported_feature : get_supported_features()) {\n                    if (required_feature == supported_feature) {\n                        is_supported = true;\n                        break;\n                    }\n                }\n                if (!is_supported) {\n                    fprintf(stderr, \"  - %3d: %s\\n\", static_cast<int>(required_feature), get_dynamic_enum_name(required_feature).c_str());\n                }\n            }\n            fprintf(stderr, \"\\n\");\n            Arch cur_gpu_arch = Arch();\n            fprintf(stderr, \"Current GPU: %s, SM %d.%d with %d SMs\\n\", cur_gpu_arch.device_prop->name, cur_gpu_arch.major, cur_gpu_arch.minor, cur_gpu_arch.num_sms);\n            fprintf(stderr, \"This means that the dispatcher has chosen an implementation that does not support all required features. Maybe there is a bug in the dispatcher, or you have requested an invalid combination of features.\\n\");\n            TORCH_CHECK(false, \"The chosen implementation does not support all required features. See message above for details.\");\n        }\n    }\n\n    inline void run(const RunArgT &params, const std::vector<FeatureT> &required_features) {\n        check_if_all_features_are_supported_and_abort(required_features);\n        run_(params, required_features);\n    }\n};\n\n"
  },
  {
    "path": "csrc/api/dense_decode.h",
    "content": "#pragma once\n\n#include <cutlass/half.h>\n#include <cutlass/fast_math.h>\n\n#include \"common.h\"\n#include \"params.h\"\n\n#include \"sm90/decode/dense/splitkv_mla.h\"\n#include \"smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h\"\n#include \"smxx/decode/combine/combine.h\"\n\nstatic std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>\ndense_attn_decode_interface(\n    at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size\n    const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)\n    const int head_size_v,\n    const at::Tensor &seqlens_k,                 // batch_size\n    const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq\n    const float softmax_scale,\n    bool is_causal,\n    std::optional<at::Tensor> &tile_scheduler_metadata,   // num_sm_parts x (DecodingSchedMetaSize/4)\n    std::optional<at::Tensor> &num_splits                 // batch_size + 1\n) {\n    // Check arch\n    Arch arch = Arch();\n    if (!arch.is_sm90a()) {\n        TORCH_CHECK(false, \"Dense decode MLA is only supported on SM90a architecture\");\n    }\n\n    // Check data types\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);\n    \n    TORCH_CHECK(kcache.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, \"seqlens_k must have dtype int32\");\n    TORCH_CHECK(block_table.dtype() == torch::kInt32, \"block_table must have dtype torch.int32\");\n\n    // Check device\n    KU_CHECK_DEVICE(q);\n    KU_CHECK_DEVICE(kcache);\n    KU_CHECK_DEVICE(seqlens_k);\n    KU_CHECK_DEVICE(block_table);\n    KU_CHECK_DEVICE(tile_scheduler_metadata);\n    KU_CHECK_DEVICE(num_splits);\n\n    // Check layout\n 
   TORCH_CHECK(q.stride(-1) == 1, \"q must have contiguous last dimension\");\n    TORCH_CHECK(kcache.stride(-1) == 1, \"kcache must have contiguous last dimension\");\n    KU_CHECK_CONTIGUOUS(seqlens_k);\n    TORCH_CHECK(block_table.stride(-1) == 1, \"block_table must have contiguous last dimension\");\n    KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);\n    KU_CHECK_CONTIGUOUS(num_splits);\n\n    const auto sizes = q.sizes();\n    const int batch_size = sizes[0];\n    const int seqlen_q_ori = sizes[1];\n    const int num_heads_q = sizes[2];\n    const int head_size_k = sizes[3];\n    TORCH_CHECK(head_size_k == 576 || head_size_k == 512, \"Only head_size_k == 576 or 512 is supported\");\n    TORCH_CHECK(head_size_v == 512, \"Only head_size_v == 512 is supported\");\n    \n    const int max_num_blocks_per_seq = block_table.size(1);\n    const int num_blocks = kcache.size(0);\n    const int page_block_size = kcache.size(1);\n    const int num_heads_k = kcache.size(2);\n    TORCH_CHECK(page_block_size == 64, \"Currently page_block_size must be 64\");\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(num_heads_q % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n    \n    if (seqlen_q_ori == 1) { is_causal = false; }\n    \n    const int num_q_heads_per_hk = num_heads_q / num_heads_k;\n    const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;\n    const int num_heads = num_heads_k;\n    q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)\n        .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});\n    int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 64), 1);\n\n    KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);\n    KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);\n    KU_CHECK_SHAPE(seqlens_k, batch_size);\n    
KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);\n    KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int));\n    KU_CHECK_SHAPE(num_splits, batch_size+1);\n\n    at::cuda::CUDAGuard device_guard{(char)q.get_device()};\n\n    auto opts = q.options();\n    at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts);\n    at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));\n    KU_CHECK_CONTIGUOUS(out);\n    KU_CHECK_CONTIGUOUS(lse);\n\n    if (!tile_scheduler_metadata.has_value()) {\n        tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));\n        num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));\n        KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);\n        KU_CHECK_CONTIGUOUS(num_splits);\n    \n        GetDecodeSchedMetaParams get_sched_meta_params = {\n            batch_size, seqlen_q_ori,\n            64,\n            5,\n            -1, -1,\n            nullptr, nullptr,\n            seqlens_k.data_ptr<int>(),\n            (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),\n            num_splits->data_ptr<int>(),\n            num_sm_parts,\n            at::cuda::getCurrentCUDAStream().stream()\n        };\n        smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);\n    } else {\n        KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);\n        KU_CHECK_DTYPE(num_splits, torch::kInt32);\n        KU_CHECK_DEVICE(tile_scheduler_metadata);\n        KU_CHECK_DEVICE(num_splits);\n        KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);\n        KU_CHECK_CONTIGUOUS(num_splits);\n        KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));\n        KU_CHECK_SHAPE(num_splits, batch_size+1);\n    }\n\n    // Set the sizes\n    DenseAttnDecodeParams params;\n    params.b = batch_size;\n    
params.s_q = seqlen_q_ori;\n    params.q_seq_per_hk = q_seq_per_hk;\n    params.seqlens_k_ptr = seqlens_k.data_ptr<int>();\n    params.h_q = num_heads_q;\n    params.h_k = num_heads_k;\n    params.num_blocks = num_blocks;\n    params.q_head_per_hk = num_q_heads_per_hk;\n    params.is_causal = is_causal;\n    params.d = head_size_k;\n    params.d_v = head_size_v;\n    params.scale_softmax = softmax_scale;\n    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);\n    // Set the pointers and strides.\n    params.q_ptr = q.data_ptr();\n    params.k_ptr = kcache.data_ptr();\n    params.o_ptr = out.data_ptr();\n    params.softmax_lse_ptr = lse.data_ptr<float>();\n    // All strides are in elements, not bytes.\n    params.q_batch_stride = q.stride(0);\n    params.k_batch_stride = kcache.stride(0);\n    params.o_batch_stride = out.stride(0);\n    params.q_row_stride = q.stride(1);\n    params.k_row_stride = kcache.stride(1);\n    params.o_row_stride = out.stride(2);\n    params.q_head_stride = q.stride(2);\n    params.k_head_stride = kcache.stride(2);\n    params.o_head_stride = out.stride(1);\n\n    params.block_table = block_table.data_ptr<int>();\n    params.block_table_batch_stride = block_table.stride(0);\n    params.page_block_size = page_block_size;\n    \n    params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();\n    params.num_sm_parts = num_sm_parts;\n    params.num_splits_ptr = num_splits->data_ptr<int>();\n\n    const int total_num_splits = batch_size + params.num_sm_parts;\n    at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));\n    at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));\n    KU_CHECK_CONTIGUOUS(lse_accum);\n    KU_CHECK_CONTIGUOUS(out_accum);\n    params.total_num_splits = total_num_splits;\n    params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();\n    params.oaccum_ptr = 
out_accum.data_ptr<float>();\n\n    params.stream = at::cuda::getCurrentCUDAStream().stream();\n\n    if (q_dtype == torch::kBFloat16) {\n        sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);\n    } else if (q_dtype == torch::kHalf) {\n#ifdef FLASH_MLA_DISABLE_FP16\n        TORCH_CHECK(false, \"FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.\");\n#else\n        sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params);\n#endif\n    } else {\n        TORCH_CHECK(false, \"Unsupported dtype for dense MLA on SM90\");\n    }\n\n    CombineParams combine_params = {\n        batch_size, seqlen_q_ori,\n        num_heads_q, head_size_v,\n\n        params.softmax_lse_ptr,\n        params.o_ptr,\n        num_heads*q_seq_per_hk, num_heads_q,\n        num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,\n\n        params.softmax_lseaccum_ptr,\n        params.oaccum_ptr,\n        num_heads*q_seq_per_hk, num_heads_q,\n        num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,\n\n        params.tile_scheduler_metadata_ptr,\n        params.num_splits_ptr,\n        params.num_sm_parts,\n\n        nullptr,\n        at::cuda::getCurrentCUDAStream().stream()\n    };\n\n    if (q_dtype == torch::kBFloat16) {\n        smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);\n    } else if (q_dtype == torch::kHalf) {\n#ifndef FLASH_MLA_DISABLE_FP16\n        smxx::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);\n#endif\n    } else {\n        TORCH_CHECK(false, \"Unsupported tensor dtype for query\");\n    }\n\n    out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)\n            .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});\n    lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)\n            
.reshape({batch_size, num_heads_q, seqlen_q_ori});\n\n    return {out, lse, tile_scheduler_metadata, num_splits};\n}\n"
  },
  {
    "path": "csrc/api/dense_fwd.h",
    "content": "#pragma once\n\n#include \"common.h\"\n\n#include \"sm100/prefill/dense/interface.h\"\n"
  },
  {
    "path": "csrc/api/sparse_decode.h",
    "content": "#pragma once\n\n#include \"common.h\"\n\n#include \"params.h\"\n\n#include \"sm90/decode/sparse_fp8/splitkv_mla.h\"\n#include \"sm100/decode/head64/kernel.h\"\n#include \"sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h\"\n#include \"smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h\"\n#include \"smxx/decode/combine/combine.h\"\n\n// Feature set of sparse decoding kernels\nenum class DecodeFeatures : int {\n    HEAD_64,\n    HEAD_128,\n\n    HEAD_DIM_576,\n    HEAD_DIM_512,\n\n    V32_KVCACHE_FORMAT,\n    MODEL1_KVCACHE_FORMAT,\n\n    ATTN_SINK,\n    TOPK_LENGTH,\n    EXTRA_KVCACHE,\n    EXTRA_TOPK_LENGTH\n};\n\nstruct DecodeImplMeta {\n    int num_sm_parts;\n    int fixed_overhead_num_blocks;\n    int block_size_topk;\n};\n\nclass DecodeImplBase : public ImplBase<\n    SparseAttnDecodeParams,\n    DecodeFeatures\n> {\npublic:\n    virtual DecodeImplMeta get_meta(int h_q, int s_q) = 0;\n};\n\nclass Decode_Sm90_Impl : public DecodeImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        DecodeFeatures::HEAD_64,\n        DecodeFeatures::HEAD_128,\n        DecodeFeatures::HEAD_DIM_512,\n        DecodeFeatures::HEAD_DIM_576,\n        DecodeFeatures::V32_KVCACHE_FORMAT,\n        DecodeFeatures::MODEL1_KVCACHE_FORMAT,\n        DecodeFeatures::ATTN_SINK,\n        DecodeFeatures::TOPK_LENGTH,\n        DecodeFeatures::EXTRA_KVCACHE,\n        DecodeFeatures::EXTRA_TOPK_LENGTH\n    )\n\npublic:\n    DecodeImplMeta get_meta(int h_q, int s_q) override {\n        Arch arch = Arch();\n        return {\n            std::max(arch.num_sms / s_q / (h_q/64), 1),\n            5,\n            64\n        };\n    }\n\nprotected:\n    void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {\n        DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {\n            DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() {\n                
sm90::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE, NUM_HEADS>(params);\n            });\n        });\n    }\n};\n\nclass Decode_Sm100_Head64_Impl : public DecodeImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        DecodeFeatures::HEAD_64,\n        DecodeFeatures::HEAD_DIM_512,\n        DecodeFeatures::HEAD_DIM_576,\n        DecodeFeatures::V32_KVCACHE_FORMAT,\n        DecodeFeatures::MODEL1_KVCACHE_FORMAT,\n        DecodeFeatures::ATTN_SINK,\n        DecodeFeatures::TOPK_LENGTH,\n        DecodeFeatures::EXTRA_KVCACHE,\n        DecodeFeatures::EXTRA_TOPK_LENGTH\n    )\n\npublic:\n    DecodeImplMeta get_meta(int h_q, int s_q) override {\n        Arch arch = Arch();\n        return {\n            std::max(arch.num_sms / s_q, 1),\n            5,\n            64\n        };\n    }\n\nprotected:\n    void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {\n        DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {\n            sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(params);\n        });\n    }\n};\n\n\n// An implementation that calls the head64 kernel twice to process head128\n// Necessary for running V3.2 shape (i.e. 
h = 128, d_qk = 576) on SM100f\nclass Decode_Sm100_Head64x2_Impl : public DecodeImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        DecodeFeatures::HEAD_128,\n        DecodeFeatures::HEAD_DIM_512,\n        DecodeFeatures::HEAD_DIM_576,\n        DecodeFeatures::V32_KVCACHE_FORMAT,\n        DecodeFeatures::MODEL1_KVCACHE_FORMAT,\n        DecodeFeatures::ATTN_SINK,\n        DecodeFeatures::TOPK_LENGTH,\n        DecodeFeatures::EXTRA_KVCACHE,\n        DecodeFeatures::EXTRA_TOPK_LENGTH\n    )\n\npublic:\n    DecodeImplMeta get_meta(int h_q, int s_q) override {\n        Arch arch = Arch();\n        return {\n            std::max(arch.num_sms / s_q, 1),\n            5,\n            64\n        };\n    }\n\nprotected:\n    void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {\n        DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {\n            for (int start_head_idx = 0; start_head_idx < 128; start_head_idx += 64) {\n                SparseAttnDecodeParams cur_params = params;\n                cur_params.q += start_head_idx * params.stride_q_h_q;\n                if (cur_params.attn_sink) {\n                    cur_params.attn_sink += start_head_idx;\n                }\n                cur_params.lse += start_head_idx;\n                cur_params.out += start_head_idx * params.stride_o_h_q;\n                cur_params.lse_accum += start_head_idx;\n                cur_params.o_accum += start_head_idx * params.stride_o_accum_h_q;\n                cur_params.h_q = 64;\n                sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(cur_params);\n            }\n        });\n    }\n};\n\n\nclass Decode_Sm100_Head128_Impl : public DecodeImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        DecodeFeatures::HEAD_128,\n        DecodeFeatures::HEAD_DIM_512,\n        DecodeFeatures::MODEL1_KVCACHE_FORMAT,\n        DecodeFeatures::ATTN_SINK,\n        DecodeFeatures::TOPK_LENGTH,\n        
DecodeFeatures::EXTRA_KVCACHE,\n        DecodeFeatures::EXTRA_TOPK_LENGTH\n    )\n\npublic:\n    DecodeImplMeta get_meta(int h_q, int s_q) override {\n        Arch arch = Arch();\n        return {\n            std::max(arch.num_sms / s_q / 2, 1),\n            3,\n            64\n        };\n    }\n\nprotected:\n    void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {\n        sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::DecodeWithSplitKV, 512>(params);\n    }\n};\n\nstatic std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>\nsparse_attn_decode_interface(\n    const at::Tensor &q,   // [b, s_q, h_q, d_qk]\n    const at::Tensor &kv,   // [num_blocks, page_block_size, h_k, d_qk]\n    const at::Tensor &indices,    // [b, s_q, topk]\n    const std::optional<at::Tensor> &topk_length,   // [b, s_q]\n    const std::optional<at::Tensor> &attn_sink, // [h_q]\n    std::optional<at::Tensor> &tile_scheduler_metadata,   // num_sm_parts x (DecodingSchedMetaSize/4)\n    std::optional<at::Tensor> &num_splits,                // batch_size + 1\n    const std::optional<at::Tensor> &extra_kv,\n    const std::optional<at::Tensor> &extra_indices,\n    const std::optional<at::Tensor> &extra_topk_length,\n    int d_v,\n    float sm_scale\n) {\n    using bf16 = cutlass::bfloat16_t;\n\n    // Check the architecture\n    Arch arch = Arch();\n\n    KU_CHECK_NDIM(q, 4);\n    KU_CHECK_NDIM(kv, 4);\n    KU_CHECK_NDIM(indices, 3);\n\n    int b = q.size(0);\n    int s_q = q.size(1);\n    int h_q = q.size(2);\n    int d_qk = q.size(3);\n    int num_blocks = kv.size(0);\n    int page_block_size = kv.size(1);\n    int h_kv = kv.size(2);\n    int topk = indices.size(2);\n\n    bool have_topk_length = topk_length.has_value();\n    bool have_extra_kcache = extra_kv.has_value();\n    bool have_extra_topk_length = extra_topk_length.has_value();\n    bool have_attn_sink = 
attn_sink.has_value();\n\n    int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0;\n    if (have_extra_kcache) {\n        extra_num_blocks = extra_kv->size(0);\n        extra_page_block_size = extra_kv->size(1);\n    }\n    if (extra_indices.has_value()) {\n        extra_topk = extra_indices->size(-1);\n    }\n\n    // metadata sanity check\n    TORCH_CHECK(b > 0);\n    TORCH_CHECK(s_q > 0);\n    TORCH_CHECK(h_q > 0);\n    TORCH_CHECK(h_kv == 1, \"Currently only MQA (i.e. h_kv == 1) is supported for sparse decoding\");\n    TORCH_CHECK(d_qk == 576 || d_qk == 512, \"Only head_size_k == 576 or 512 is supported for sparse decoding\");\n    TORCH_CHECK(d_v == 512, \"Only head_size_v == 512 is supported for sparse decoding\");\n    TORCH_CHECK(topk > 0);\n\n    if (have_extra_kcache) {\n        TORCH_CHECK(extra_indices.has_value(), \"extra_indices_in_kvcache must be provided when extra_kcache is provided for sparse attention\");\n    } else {\n        TORCH_CHECK(!extra_indices.has_value(), \"extra_indices_in_kvcache must not be provided when extra_k_cache is not provided\");\n        TORCH_CHECK(!extra_topk_length.has_value(), \"extra_topk_length must not be provided when extra_k_cache is not provided\");\n    }\n\n    // Check device\n    KU_CHECK_DEVICE(q);\n    KU_CHECK_DEVICE(kv);\n    KU_CHECK_DEVICE(indices);\n    KU_CHECK_DEVICE(topk_length);\n    KU_CHECK_DEVICE(attn_sink);\n    KU_CHECK_DEVICE(tile_scheduler_metadata);\n    KU_CHECK_DEVICE(num_splits);\n    KU_CHECK_DEVICE(extra_kv);\n    KU_CHECK_DEVICE(extra_indices);\n    KU_CHECK_DEVICE(extra_topk_length);\n\n    // Check data type\n    KU_CHECK_DTYPE(q, torch::kBFloat16);\n    TORCH_CHECK(kv.dtype() == torch::kFloat8_e4m3fn || kv.dtype() == torch::kInt8 || kv.dtype() == torch::kUInt8, \"key must have dtype fp8_e4m3fn, int8 or uint8\");\n    if (extra_kv.has_value()) {\n        TORCH_CHECK(extra_kv->dtype() == torch::kFloat8_e4m3fn || extra_kv->dtype() == torch::kInt8 || extra_kv->dtype() 
== torch::kUInt8, \"extra k cache must have dtype fp8_e4m3fn, int8 or uint8\");\n    }\n    KU_CHECK_DTYPE(indices, torch::kInt32);\n    KU_CHECK_DTYPE(topk_length, torch::kInt32);\n    KU_CHECK_DTYPE(attn_sink, torch::kFloat32);\n    KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);\n    KU_CHECK_DTYPE(num_splits, torch::kInt32);\n    KU_CHECK_DTYPE(extra_indices, torch::kInt32);\n    KU_CHECK_DTYPE(extra_topk_length, torch::kInt32);\n    \n    // Check layout\n    KU_CHECK_LAST_DIM_CONTIGUOUS(q);\n    KU_CHECK_LAST_DIM_CONTIGUOUS(kv);\n    KU_CHECK_LAST_DIM_CONTIGUOUS(indices);\n    KU_CHECK_CONTIGUOUS(topk_length);\n    KU_CHECK_CONTIGUOUS(attn_sink);\n\n    KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);\n    KU_CHECK_CONTIGUOUS(num_splits);\n\n    KU_CHECK_LAST_DIM_CONTIGUOUS(extra_kv);\n    KU_CHECK_LAST_DIM_CONTIGUOUS(extra_indices);\n    KU_CHECK_CONTIGUOUS(extra_topk_length);\n    \n    // Check shape\n    KU_CHECK_SHAPE(q, b, s_q, h_q, d_qk);\n    {\n        int bytes_per_token;\n        if (d_qk == 576 && d_v == 512) {\n            // V3.2 style\n            bytes_per_token = 512 + 64*2 + (512/128)*4;\n        } else if (d_qk == 512 && d_v == 512) {\n            // MODEL1 style\n            bytes_per_token = 448 + 64*2 + (448/64)*1 + 1;\n        } else {\n            TORCH_CHECK(false, \"Unsupported head sizes for is_fp8_kvcache == True\");\n        }\n        KU_CHECK_SHAPE(kv, num_blocks, page_block_size, h_kv, bytes_per_token);\n        KU_CHECK_SHAPE(extra_kv, extra_num_blocks, extra_page_block_size, h_kv, bytes_per_token);\n        TORCH_CHECK(kv.stride(1) == bytes_per_token, \"The whole block must be contiguous when is_fp8_cache is True for kv cache\");\n        if (extra_kv.has_value()) {\n            TORCH_CHECK(extra_kv->stride(1) == bytes_per_token, \"The whole block must be contiguous when is_fp8_cache is True for extra kv cache\");\n        }\n    }\n    KU_CHECK_SHAPE(indices, b, s_q, topk);\n    KU_CHECK_SHAPE(topk_length, b);\n    
KU_CHECK_SHAPE(attn_sink, h_q);\n    KU_CHECK_SHAPE(extra_indices, b, s_q, extra_topk);\n    KU_CHECK_SHAPE(extra_topk_length, b);\n\n    at::cuda::CUDAGuard device_guard{(char)q.get_device()};\n    auto opts = q.options();\n\n    at::Tensor out = torch::empty({b, s_q, h_q, d_v}, opts);\n    at::Tensor lse = torch::empty({b, s_q, h_q}, opts.dtype(at::kFloat));\n\n    ModelType model_type;\n    if (d_qk == 576) {\n        model_type = ModelType::V32;\n    } else if (d_qk == 512) {\n        model_type = ModelType::MODEL1;\n    } else {\n        TORCH_CHECK(false, \"Unsupported d_qk: \", d_qk);\n    }\n\n    std::vector<DecodeFeatures> features;\n    if (h_q == 64) {\n        features.push_back(DecodeFeatures::HEAD_64);\n    } else if (h_q == 128) {\n        features.push_back(DecodeFeatures::HEAD_128);\n    } else {\n        TORCH_CHECK(false, \"Unsupported h_q: \", h_q);\n    }\n    if (d_qk == 576) {\n        features.push_back(DecodeFeatures::HEAD_DIM_576);\n    } else if (d_qk == 512) {\n        features.push_back(DecodeFeatures::HEAD_DIM_512);\n    } else {\n        TORCH_CHECK(false, \"Unsupported d_qk: \", d_qk);\n    }\n    if (model_type == ModelType::V32) {\n        features.push_back(DecodeFeatures::V32_KVCACHE_FORMAT);\n    } else if (model_type == ModelType::MODEL1) {\n        features.push_back(DecodeFeatures::MODEL1_KVCACHE_FORMAT);\n    } else {\n        TORCH_CHECK(false, \"Unsupported model type: \", (int)model_type);\n    }\n    if (have_attn_sink) {\n        features.push_back(DecodeFeatures::ATTN_SINK);\n    }\n    if (have_topk_length) {\n        features.push_back(DecodeFeatures::TOPK_LENGTH);\n    }\n    if (have_extra_kcache) {\n        features.push_back(DecodeFeatures::EXTRA_KVCACHE);\n    }\n    if (have_extra_topk_length) {\n        features.push_back(DecodeFeatures::EXTRA_TOPK_LENGTH);\n    }\n\n    DecodeImplBase* impl;\n    if (arch.is_sm100f()) {\n        if (h_q == 64) {\n            impl = new Decode_Sm100_Head64_Impl();\n        } 
else if (h_q == 128) {\n            if (d_qk == 576) {\n                impl = new Decode_Sm100_Head64x2_Impl();\n            } else if (d_qk == 512) {\n                impl = new Decode_Sm100_Head128_Impl();\n            } else {\n                TORCH_CHECK(false, \"Unsupported d_qk: \", d_qk);\n            }\n        } else {\n            TORCH_CHECK(false, \"Unsupported h_q: \", h_q);\n        }\n    } else if (arch.is_sm90a()) {\n        impl = new Decode_Sm90_Impl();\n    } else {\n        TORCH_CHECK(false, \"Unsupported architecture for sparse decode fwd\");\n    }\n\n    DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q);\n\n    SparseAttnDecodeParams params = {\n        b, s_q, h_q, h_kv, d_qk, d_v,\n        sm_scale, sm_scale * LOG_2_E,\n        num_blocks, page_block_size, topk,\n        model_type,\n\n        (bf16*)q.data_ptr(),\n        (bf16*)kv.data_ptr(),\n        (int*)indices.data_ptr(),\n        ku::get_optional_tensor_ptr<int>(topk_length),\n        ku::get_optional_tensor_ptr<float>(attn_sink),\n        (float*)lse.data_ptr(),\n        (bf16*)out.data_ptr(),\n\n        extra_num_blocks, extra_page_block_size, extra_topk,\n        ku::get_optional_tensor_ptr<bf16>(extra_kv),\n        ku::get_optional_tensor_ptr<int>(extra_indices),\n        ku::get_optional_tensor_ptr<int>(extra_topk_length),\n\n        int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), int64_stride_to_int(q.stride(2)),\n        int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),\n        int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),\n        int64_stride_to_int(lse.stride(0)), int64_stride_to_int(lse.stride(1)),\n        int64_stride_to_int(out.stride(0)), int64_stride_to_int(out.stride(1)), int64_stride_to_int(out.stride(2)),\n\n        have_extra_kcache ? int64_stride_to_int(extra_kv->stride(0)) : 0,\n        have_extra_kcache ? int64_stride_to_int(extra_kv->stride(1)) : 0,\n        have_extra_kcache ? 
int64_stride_to_int(extra_indices->stride(0)) : 0,\n        have_extra_kcache ? int64_stride_to_int(extra_indices->stride(1)) : 0,\n        at::cuda::getCurrentCUDAStream().stream()\n    };\n\n    // Get MLA metadata if necessary\n    at::Tensor o_accum, lse_accum;\n    if (!tile_scheduler_metadata.has_value()) {\n        tile_scheduler_metadata = torch::empty({impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int)}, opts.dtype(torch::kInt32));\n        num_splits = torch::empty({b+1}, opts.dtype(torch::kInt32));\n        KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);\n        KU_CHECK_CONTIGUOUS(num_splits);\n\n        GetDecodeSchedMetaParams get_sched_meta_params = {\n            b, s_q,\n            impl_meta.block_size_topk,\n            impl_meta.fixed_overhead_num_blocks,\n            topk,\n            extra_topk,\n            ku::get_optional_tensor_ptr<int>(topk_length),\n            ku::get_optional_tensor_ptr<int>(extra_topk_length),\n            nullptr,\n            (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),\n            num_splits->data_ptr<int>(),\n            impl_meta.num_sm_parts,\n            at::cuda::getCurrentCUDAStream().stream()\n        };\n        smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);\n    }\n    // Validate the metadata tensors and attach their pointers to `params`\n    KU_CHECK_DEVICE(tile_scheduler_metadata);\n    KU_CHECK_DEVICE(num_splits);\n    KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);\n    KU_CHECK_DTYPE(num_splits, torch::kInt32);\n    KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);\n    KU_CHECK_CONTIGUOUS(num_splits);\n    KU_CHECK_SHAPE(tile_scheduler_metadata, impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));\n    KU_CHECK_SHAPE(num_splits, b+1);\n    params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();\n    params.num_splits_ptr = num_splits->data_ptr<int>();\n    params.num_sm_parts = impl_meta.num_sm_parts;\n\n    // Allocate intermediate 
buffers for split-KV\n    const int total_num_splits = b + impl_meta.num_sm_parts;\n    lse_accum = torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat));\n    o_accum = torch::empty({total_num_splits, s_q, h_q, d_v}, opts.dtype(at::kFloat));\n    KU_CHECK_CONTIGUOUS(lse_accum);\n    KU_CHECK_CONTIGUOUS(o_accum);\n    params.lse_accum = lse_accum.data_ptr<float>();\n    params.o_accum = o_accum.data_ptr<float>();\n    params.stride_lse_accum_split = int64_stride_to_int(lse_accum.stride(0));\n    params.stride_lse_accum_s_q = int64_stride_to_int(lse_accum.stride(1));\n    params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0));\n    params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1));\n    params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2));\n\n    impl->run(params, features);\n    \n    CombineParams combine_params = {\n        b, s_q, h_q, d_v,\n\n        params.lse,\n        params.out,\n        params.stride_lse_b, params.stride_lse_s_q,\n        params.stride_o_b, params.stride_o_s_q, params.stride_o_h_q,\n\n        params.lse_accum,\n        params.o_accum,\n        params.stride_lse_accum_split, params.stride_lse_accum_s_q,\n        params.stride_o_accum_split, params.stride_o_accum_s_q, params.stride_o_accum_h_q,\n\n        params.tile_scheduler_metadata_ptr,\n        params.num_splits_ptr,\n        params.num_sm_parts,\n\n        ku::get_optional_tensor_ptr<float>(attn_sink),\n        at::cuda::getCurrentCUDAStream().stream()\n    };\n    smxx::decode::run_flash_mla_combine_kernel<bf16>(combine_params);\n\n    delete impl;\n\n    return {out, lse.transpose(1, 2), tile_scheduler_metadata, num_splits};\n}\n"
  },
  {
    "path": "csrc/api/sparse_fwd.h",
    "content": "#pragma once\n\n#include \"common.h\"\n\n#include \"params.h\"\n\n#include \"sm90/prefill/sparse/phase1.h\"\n#include \"sm100/prefill/sparse/fwd/head128/phase1.h\"\n#include \"sm100/prefill/sparse/fwd/head64/phase1.h\"\n#include \"sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h\"\n\nenum class FwdFeatures : int {\n    HEAD_64,\n    HEAD_128,\n\n    HEAD_DIM_576,\n    HEAD_DIM_512,\n\n    ATTN_SINK,\n    SINK_LSE,\n    TOPK_LENGTH\n};\n\nclass FwdImplBase : public ImplBase<\n    SparseAttnFwdParams,\n    FwdFeatures\n> {};\n\nclass Fwd_Sm90_Impl : public FwdImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        FwdFeatures::HEAD_64,\n        FwdFeatures::HEAD_128,\n        FwdFeatures::HEAD_DIM_512,\n        FwdFeatures::HEAD_DIM_576,\n        FwdFeatures::ATTN_SINK,\n        FwdFeatures::SINK_LSE,\n        FwdFeatures::TOPK_LENGTH\n    )\n\nprotected:\n    void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {\n        DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {\n            DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {\n                sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);\n            });\n        });\n    }\n};\n\nclass Fwd_Sm100_Head64_Impl : public FwdImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        FwdFeatures::HEAD_64,\n        FwdFeatures::HEAD_DIM_512,\n        FwdFeatures::HEAD_DIM_576,\n        FwdFeatures::ATTN_SINK,\n        FwdFeatures::SINK_LSE,\n        FwdFeatures::TOPK_LENGTH\n    )\n\nprotected:\n    void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {\n        DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {\n            sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);\n        });\n    }\n};\n\nclass Fwd_Sm100_Head128_Impl : public FwdImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        FwdFeatures::HEAD_128,\n        
FwdFeatures::HEAD_DIM_512,\n        FwdFeatures::HEAD_DIM_576,\n        FwdFeatures::ATTN_SINK,\n        FwdFeatures::SINK_LSE,\n        FwdFeatures::TOPK_LENGTH\n    )\n\nprotected:\n    void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {\n        DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {\n            sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);\n        });\n    }\n};\n\nclass Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase {\n    DECLARE_SUPPORTED_FEATURES(\n        FwdFeatures::HEAD_128,\n        FwdFeatures::HEAD_DIM_512,\n        FwdFeatures::ATTN_SINK,\n        FwdFeatures::SINK_LSE,\n        FwdFeatures::TOPK_LENGTH\n    )\n\nprotected:\n    void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {\n        sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params);\n    }\n};\n\nstatic std::vector<at::Tensor> sparse_attn_prefill_interface(\n    const at::Tensor &q,\n    const at::Tensor &kv,\n    const at::Tensor &indices,\n    float sm_scale,\n    int d_v,\n    const std::optional<at::Tensor> &attn_sink,\n    const std::optional<at::Tensor> &topk_length\n) {\n    using bf16 = cutlass::bfloat16_t;\n    \n    Arch arch = Arch();\n    bool is_sm90a = arch.is_sm90a();\n    bool is_sm100f = arch.is_sm100f();\n    TORCH_CHECK(is_sm90a || is_sm100f, \"Sparse Attention Forward Kernel is only supported on SM90a and SM100f architectures.\");\n\n    KU_CHECK_NDIM(q, 3);\n    KU_CHECK_NDIM(kv, 3);\n    KU_CHECK_NDIM(indices, 3);\n    KU_CHECK_NDIM(attn_sink, 1);\n    KU_CHECK_NDIM(topk_length, 1);\n\n    int s_q = q.size(0);\n    int s_kv = kv.size(0);\n    int h_q = q.size(1);\n    int h_kv = kv.size(1);\n    int d_qk = q.size(2);\n    int topk = indices.size(2);\n    bool have_topk_length = topk_length.has_value();\n\n    TORCH_CHECK(d_qk == 576 || d_qk == 512, \"Invalid d_qk: 
\", d_qk);\n    TORCH_CHECK(d_v == 512, \"Invalid d_v\", d_v);\n    \n    KU_CHECK_DEVICE(q);\n    KU_CHECK_DEVICE(kv);\n    KU_CHECK_DEVICE(indices);\n    KU_CHECK_DEVICE(attn_sink);\n    KU_CHECK_DEVICE(topk_length);\n    \n    KU_CHECK_DTYPE(q, torch::kBFloat16);\n    KU_CHECK_DTYPE(kv, torch::kBFloat16);\n    KU_CHECK_DTYPE(indices, torch::kInt32);\n    KU_CHECK_DTYPE(attn_sink, torch::kFloat32);\n    KU_CHECK_DTYPE(topk_length, torch::kInt32);\n    \n    KU_CHECK_SHAPE(q, s_q, h_q, d_qk);\n    KU_CHECK_SHAPE(kv, s_kv, h_kv, d_qk);\n    KU_CHECK_SHAPE(indices, s_q, h_kv, topk);\n    KU_CHECK_SHAPE(attn_sink, h_q);\n    KU_CHECK_SHAPE(topk_length, s_q);\n    \n    KU_CHECK_LAST_DIM_CONTIGUOUS(q);\n    KU_CHECK_LAST_DIM_CONTIGUOUS(kv);\n    KU_CHECK_LAST_DIM_CONTIGUOUS(indices);\n    KU_CHECK_LAST_DIM_CONTIGUOUS(attn_sink);\n    KU_CHECK_LAST_DIM_CONTIGUOUS(topk_length);\n    \n    // Allocate results and buffers\n    at::cuda::CUDAGuard device_guard{(char)q.get_device()};\n    auto opts = q.options();\n    \n    at::Tensor out = torch::empty({s_q, h_q, d_v}, opts);\n    at::Tensor lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));\n    at::Tensor max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));\n    KU_CHECK_CONTIGUOUS(out);\n    KU_CHECK_CONTIGUOUS(lse);\n    KU_CHECK_CONTIGUOUS(max_logits);\n\n    SparseAttnFwdParams params = {\n        s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,\n        sm_scale, sm_scale * LOG_2_E,\n\n        (bf16*)q.data_ptr(),\n        (bf16*)kv.data_ptr(),\n        (int*)indices.data_ptr(),\n        ku::get_optional_tensor_ptr<float>(attn_sink),\n        ku::get_optional_tensor_ptr<int>(topk_length),\n\n        int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)),\n        int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),\n        int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),\n\n        (bf16*)out.data_ptr(),\n        (float*)max_logits.data_ptr(),\n 
       (float*)lse.data_ptr(),\n\n        arch.num_sms,\n        at::cuda::getCurrentCUDAStream().stream()\n    };\n\n    std::vector<FwdFeatures> required_features;\n    if (h_q == 64) {\n        required_features.push_back(FwdFeatures::HEAD_64);\n    } else if (h_q == 128) {\n        required_features.push_back(FwdFeatures::HEAD_128);\n    } else {\n        TORCH_CHECK(false, \"Unsupported h_q: \", h_q);\n    }\n    if (d_qk == 576) {\n        required_features.push_back(FwdFeatures::HEAD_DIM_576);\n    } else if (d_qk == 512) {\n        required_features.push_back(FwdFeatures::HEAD_DIM_512);\n    } else {\n        TORCH_CHECK(false, \"Unsupported d_qk: \", d_qk);\n    }\n    if (attn_sink.has_value()) {\n        required_features.push_back(FwdFeatures::ATTN_SINK);\n    }\n    if (have_topk_length) {\n        required_features.push_back(FwdFeatures::TOPK_LENGTH);\n    }\n\n    if (is_sm90a) {\n        Fwd_Sm90_Impl fwd_impl;\n        fwd_impl.run(params, required_features);\n    } else if (is_sm100f) {\n        if (h_q == 64) {\n            Fwd_Sm100_Head64_Impl fwd_impl;\n            fwd_impl.run(params, required_features);\n        } else if (h_q == 128) {\n            Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;\n            Fwd_Sm100_Head128_Impl regular_impl;\n            bool use_small_topk_impl = false;\n            if (\n                (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) ||\n                !regular_impl.check_if_all_features_are_supported(required_features)\n            ) {\n                use_small_topk_impl = true;\n            }\n            if (use_small_topk_impl) {\n                small_topk_impl.run(params, required_features);\n            } else {\n                regular_impl.run(params, required_features);\n            }\n        } else {\n            TORCH_CHECK(false, \"Unsupported h_q: \", h_q);\n        }\n    } else {\n        TORCH_CHECK(false, \"Unsupported architecture\");\n  
  }\n\n    return {out, max_logits, lse};\n}\n"
  },
  {
    "path": "csrc/defines.h",
    "content": "#pragma once\n\n#include <cutlass/bfloat16.h>\n#include <cutlass/arch/barrier.h>\n\nusing bf16 = cutlass::bfloat16_t;\nusing fp8 = cutlass::float_e4m3_t;\nusing transac_bar_t = cutlass::arch::ClusterTransactionBarrier;\nusing cutlass::arch::fence_view_async_shared;\nusing cutlass::arch::fence_barrier_init;\nusing cutlass::arch::NamedBarrier;\n\nstruct int32x8_t {\n    int a0, a1, a2, a3, a4, a5, a6, a7;\n};\n\nstruct float8 {\n    float2 a01, a23, a45, a67;\n};\n\nstruct bf16x8 {\n    __nv_bfloat162 a01;\n    __nv_bfloat162 a23;\n    __nv_bfloat162 a45;\n    __nv_bfloat162 a67;\n};\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/common/common.h",
    "content": "#pragma once\n\nnamespace kerutils {}\n\n#define KU_PRINTLN(fmt, ...) { cute::print(fmt, ##__VA_ARGS__); print(\"\\n\"); }\n\nnamespace ku = kerutils;\n\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/common.h",
    "content": "/*\nCommon data types and macros that are used across the kerutils library.\n*/\n#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp8.h>\n\n#include <cutlass/bfloat16.h>\n#include <cutlass/arch/barrier.h>\n#include <cute/config.hpp>  // For CUTE_DEVICE\n\nnamespace kerutils {\n\n// Cache hints\nenum class CacheHint {\n    EVICT_FIRST,\n    EVICT_NORMAL,\n    EVICT_LAST,\n    EVICT_UNCHANGED,\n    NO_ALLOCATE\n};\n\n// Prefetch size\nenum class PrefetchSize {\n    B64,\n    B128,\n    B256\n};\n\nusing nvbf16 = __nv_bfloat16;\nusing nvbf16x2 = __nv_bfloat162;\nusing nve4m3 = __nv_fp8_e4m3;\nusing nve4m3x2 = __nv_fp8x2_e4m3;\nusing nve4m3x4 = __nv_fp8x4_e4m3;\n\nusing bf16 = cutlass::bfloat16_t;\nusing transac_bar_t = cutlass::arch::ClusterTransactionBarrier;\n\n}\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))\n#define KERUTILS_ENABLE_SM80\n#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))\nstatic_assert(false, \"kerutils doesn't support SM architectures below SM80\");\n#endif\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))\n#define KERUTILS_ENABLE_SM90\n#endif\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000))\n#define KERUTILS_ENABLE_SM90A\n#endif\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))\n#define KERUTILS_ENABLE_SM100\n#endif\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200))\n#define KERUTILS_ENABLE_SM100A\n#endif\n\n#if (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))\n#define KERUTILS_ENABLE_SM80\n#define KERUTILS_ENABLE_SM90\n#define KERUTILS_ENABLE_SM90A\n#define KERUTILS_ENABLE_SM100\n#define KERUTILS_ENABLE_SM100A\n#endif\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/device.cuh",
    "content": "#pragma once\n\n#include \"kerutils/common/common.h\"\n\n#include \"common.h\"\n#include \"sm80/intrinsics.cuh\"\n#include \"sm80/helpers.cuh\"\n#include \"sm90/intrinsics.cuh\"\n#include \"sm90/helpers.cuh\"\n#include \"sm100/intrinsics.cuh\"\n#include \"sm100/helpers.cuh\"\n#include \"sm100/gemm.cuh\"\n#include \"sm100/tma_cta_group2_nosplit.cuh\"\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm100/gemm.cuh",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include <kerutils/device/common.h>\n\nnamespace cute {\n\n// Extensions to CuTe\n// CuTe doesn't support UTCMMA with .ws, so we add it here\n// Besides, CuTe's UTCMMA has an `elect_one_sync()` inside, which is really disgusting, so we have our own variants without `elect_one_sync()` here\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>\nstruct SM100_MMA_F16BF16_WS_TS_NOELECT\n{\n  static_assert(M == 32 || M == 64 || M == 128, \"SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.\");\n  static_assert(N == 64 || N == 128 || N == 256,\n                \"SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256\");\n\n  using DRegisters = void;\n  using ARegisters = uint64_t[1];\n  using BRegisters = uint64_t[1];\n  using CRegisters = uint32_t[1];\n\n  CUTE_HOST_DEVICE static void\n  fma(uint32_t const& tmem_a,\n      uint64_t const& desc_b,\n      uint32_t const& tmem_c,\n      uint32_t const& scaleC,\n      uint64_t const& idescE)\n  {\n    asm volatile(\n      \"{\\n\\t\"\n      \".reg .pred p;\\n\\t\"\n      \"setp.ne.b32 p, %4, 0;\\n\\t\"\n      \"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \\n\\t\"\n      \"}\\n\"\n      :\n      : \"r\"(tmem_c), \"r\"(tmem_a), \"l\"(desc_b), \"r\"(uint32_t(idescE>>32)), \"r\"(scaleC));\n  }\n};\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>\nstruct MMA_Traits<SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,\n                                M, N,\n                                a_major, b_major,\n                                a_neg, b_neg>>\n{\n  using ValTypeD = c_type;\n  using ValTypeA = a_type;\n  
using ValTypeB = b_type;\n  using ValTypeC = c_type;\n  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, \"SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types\");\n\n  using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>; // Actually this should be \"duplicated\", however, our great CuTe doesn't allow us to set it to \"duplicated\", so we just set it to NonInterleaved for a correct address calculation\n  using FrgTypeB = UMMA::smem_desc<b_major>;\n  using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;\n\n  // Logical shape-K is always 256 bits; transform to units of elements\n  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;\n\n  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;\n  using ThrID   = Layout<_1>;\n  using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n  using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<N>>>>;\n  using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n\n  // Accumulate or overwrite C.   
1: read C, 0: ignore C [clear accumulators]\n  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;\n\n  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<\n    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();\n\n  template <class TD, class DLayout,\n            class TA, class ALayout,\n            class TB, class BLayout,\n            class TC, class CLayout>\n  CUTE_HOST_DEVICE constexpr friend\n  void\n  mma_unpack(MMA_Traits          const& traits,\n             Tensor<TD, DLayout>      & D,\n             Tensor<TA, ALayout> const& A,\n             Tensor<TB, BLayout> const& B,\n             Tensor<TC, CLayout> const& C)\n  {\n    static_assert(is_tmem<TD>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_tmem<TA>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_rmem<TB>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_tmem<TC>::value, \"Expected tmem in MMA_Atom::call\");\n\n    uint32_t tmem_a = raw_pointer_cast(A.data());\n    uint64_t desc_b = B[0];\n    uint32_t tmem_c = raw_pointer_cast(D.data());\n    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);\n\n    SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,\n                  M, N,\n                  a_major, b_major,\n                  a_neg, b_neg>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);\n  }\n};\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>\nstruct SM100_MMA_F16BF16_WS_SS_NOELECT\n{\n  static_assert(M == 32 || M == 64 || M == 128, \"SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.\");\n  static_assert(N == 64 || N == 128 || N == 256,\n                \"SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256\");\n\n  using DRegisters = void;\n  using 
ARegisters = uint64_t[1];\n  using BRegisters = uint64_t[1];\n  using CRegisters = uint32_t[1];\n\n  CUTE_HOST_DEVICE static void\n  fma(uint64_t const& desc_a,\n      uint64_t const& desc_b,\n      uint32_t const& tmem_c,\n      uint32_t const& scaleC,\n      uint64_t const& idescE)\n  {\n    asm volatile(\n      \"{\\n\\t\"\n      \".reg .pred p;\\n\\t\"\n      \"setp.ne.b32 p, %4, 0;\\n\\t\"\n      \"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \\n\\t\"\n      \"}\\n\"\n      :\n      : \"r\"(tmem_c), \"l\"(desc_a), \"l\"(desc_b), \"r\"(uint32_t(idescE>>32)), \"r\"(scaleC));\n  }\n};\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>\nstruct MMA_Traits<SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,\n                                M, N, a_major, b_major,\n                                a_neg, b_neg>>\n{\n  using ValTypeD = c_type;\n  using ValTypeA = a_type;\n  using ValTypeB = b_type;\n  using ValTypeC = c_type;\n\n  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, \"SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types\");\n\n  using FrgTypeA = UMMA::smem_desc<a_major>;\n  using FrgTypeB = UMMA::smem_desc<b_major>;\n  using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;\n\n  // Logical shape-K is always 256bits, transform to units of elements\n  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;\n\n  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;\n  using ThrID   = Layout<_1>;\n  using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n  using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<N>>>>;\n  using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n\n  
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<\n    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();\n\n  // Accumulate or overwrite C.   1: read C, 0: ignore C [clear accumulators]\n  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;\n\n  template <class TD, class DLayout,\n            class TA, class ALayout,\n            class TB, class BLayout,\n            class TC, class CLayout>\n  CUTE_HOST_DEVICE constexpr friend\n  void\n  mma_unpack(MMA_Traits          const& traits,\n             Tensor<TD, DLayout>      & D,\n             Tensor<TA, ALayout> const& A,\n             Tensor<TB, BLayout> const& B,\n             Tensor<TC, CLayout> const& C)\n  {\n    static_assert(is_tmem<TD>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_rmem<TA>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_rmem<TB>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_tmem<TC>::value, \"Expected tmem in MMA_Atom::call\");\n\n    uint64_t desc_a = A[0];\n    uint64_t desc_b = B[0];\n    uint32_t tmem_c = raw_pointer_cast(D.data());\n    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);\n\n    SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,\n                  M, N, a_major, b_major,\n                  a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);\n  }\n};\n\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,\n          UMMA::Saturate c_sat = UMMA::Saturate::False>\nstruct SM100_MMA_F16BF16_2x1SM_TS_NOELECT\n{\n  static_assert(M == 128 || M == 256, \"SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.\");\n  static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), \"SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple 
of 16 between 16 and 256.\");\n  static_assert(a_major == UMMA::Major::K, \"SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed\");\n\n  using DRegisters = void;\n  using ARegisters = uint32_t[1];\n  using BRegisters = uint64_t[1];\n  using CRegisters = uint32_t[1];\n\n  CUTE_HOST_DEVICE static void\n  fma(uint32_t const& tmem_a,\n      uint64_t const& desc_b,\n      uint32_t const& tmem_c,\n      uint32_t const& scaleC,\n      uint64_t const& idescE)\n  {\n#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)\n    uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};\n    asm volatile(\n      \"{\\n\\t\"\n      \".reg .pred p;\\n\\t\"\n      \"setp.ne.b32 p, %4, 0;\\n\\t\"\n      \"tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \\n\\t\"\n      \"}\\n\"\n      :\n      : \"r\"(tmem_c), \"r\"(tmem_a), \"l\"(desc_b), \"r\"(uint32_t(idescE>>32)), \"r\"(scaleC),\n        \"r\"(mask[0]), \"r\"(mask[1]), \"r\"(mask[2]), \"r\"(mask[3]),\n        \"r\"(mask[4]), \"r\"(mask[5]), \"r\"(mask[6]), \"r\"(mask[7]));\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED\");\n#endif\n  }\n};\n\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,\n          UMMA::Saturate c_sat>\nstruct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,\n                                     M, N,\n                                     a_major, b_major,\n                                     a_neg, b_neg, c_sat>>\n{\n  using ValTypeD = c_type;\n  using ValTypeA = a_type;\n  using ValTypeB = b_type;\n  using ValTypeC = c_type;\n  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, \"SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types\");\n\n  using FrgTypeA = 
UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;\n  using FrgTypeB = UMMA::smem_desc<b_major>;\n  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;\n\n  // The instruction's K extent is always 256 bits; convert to units of elements\n  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;\n\n  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;\n  using ThrID   = Layout<_2>;\n  using ALayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,\n                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;\n  using BLayout = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,\n                         Stride<Int<N/2>,Stride<      _1,Int<N>>>>;\n  using CLayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,\n                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;\n\n  // Accumulate or overwrite C.   1: read C, 0: ignore C [clear accumulators]\n  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;\n\n  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<\n    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();\n\n  template <class TD, class DLayout,\n            class TA, class ALayout,\n            class TB, class BLayout,\n            class TC, class CLayout>\n  CUTE_HOST_DEVICE constexpr friend\n  void\n  mma_unpack(MMA_Traits          const& traits,\n             Tensor<TD, DLayout>      & D,\n             Tensor<TA, ALayout> const& A,\n             Tensor<TB, BLayout> const& B,\n             Tensor<TC, CLayout> const& C)\n  {\n    static_assert(is_tmem<TD>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_tmem<TA>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_rmem<TB>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_tmem<TC>::value, \"Expected tmem in MMA_Atom::call\");\n\n    uint32_t tmem_a = raw_pointer_cast(A.data());\n    uint64_t desc_b = B[0];\n    uint32_t tmem_c = raw_pointer_cast(D.data());\n    uint64_t idesc = 
UMMA::make_runtime_instr_desc<>(traits.idesc_);\n\n    SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,\n                       M, N,\n                       a_major, b_major,\n                       a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);\n  }\n};\n\n\n\n// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync()\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>\nstruct SM100_MMA_F16BF16_2x1SM_SS_NOELECT\n{\n  static_assert(M == 128 || M == 256, \"SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.\");\n  static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), \"SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 16 between 16 and 256.\");\n\n  using DRegisters = void;\n  using ARegisters = uint64_t[1];\n  using BRegisters = uint64_t[1];\n  using CRegisters = uint32_t[1];\n\n  CUTE_HOST_DEVICE static void\n  fma(uint64_t const& desc_a,\n      uint64_t const& desc_b,\n      uint32_t const& tmem_c,\n      uint32_t const& scaleC,\n      uint64_t const& idescE)\n  {\n#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)\n    uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};\n    asm volatile(\n      \"{\\n\\t\"\n      \".reg .pred p;\\n\\t\"\n      \"setp.ne.b32 p, %4, 0;\\n\\t\"\n      \"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \\n\\t\"\n      \"}\\n\"\n      :\n      : \"r\"(tmem_c), \"l\"(desc_a), \"l\"(desc_b), \"r\"(uint32_t(idescE>>32)), \"r\"(scaleC),\n        \"r\"(mask[0]), \"r\"(mask[1]), \"r\"(mask[2]), \"r\"(mask[3]),\n        \"r\"(mask[4]), \"r\"(mask[5]), \"r\"(mask[6]), \"r\"(mask[7]));\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED\");\n#endif\n  
}\n};\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N,\n          UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>\nstruct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,\n                                     M, N, a_major, b_major,\n                                     a_neg, b_neg>>\n{\n  using ValTypeD = c_type;\n  using ValTypeA = a_type;\n  using ValTypeB = b_type;\n  using ValTypeC = c_type;\n  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, \"SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types\");\n\n  using FrgTypeA = UMMA::smem_desc<a_major>;\n  using FrgTypeB = UMMA::smem_desc<b_major>;\n  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;\n\n  // Size of the instruction's K extent is always 256 bits; convert to units of elements\n  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;\n\n  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;\n  using ThrID   = Layout<_2>;\n  using ALayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,\n                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;\n  using BLayout = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,\n                         Stride<Int<N/2>,Stride<      _1,Int<N>>>>;\n  using CLayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,\n                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;\n\n  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<\n    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();\n\n  // Accumulate or overwrite C.   
1: read C, 0: ignore C [clear accumulators]\n  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;\n\n  template <class TD, class DLayout,\n            class TA, class ALayout,\n            class TB, class BLayout,\n            class TC, class CLayout>\n  CUTE_HOST_DEVICE constexpr friend\n  void\n  mma_unpack(MMA_Traits          const& traits,\n             Tensor<TD, DLayout>      & D,\n             Tensor<TA, ALayout> const& A,\n             Tensor<TB, BLayout> const& B,\n             Tensor<TC, CLayout> const& C)\n  {\n    static_assert(is_tmem<TD>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_rmem<TA>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_rmem<TB>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_tmem<TC>::value, \"Expected tmem in MMA_Atom::call\");\n\n    uint64_t desc_a = A[0];\n    uint64_t desc_b = B[0];\n    uint32_t tmem_c = raw_pointer_cast(D.data());\n    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);\n\n    SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,\n                       M, N,\n                       a_major, b_major,\n                       a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);\n  }\n};\n\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,\n          UMMA::Saturate c_sat = UMMA::Saturate::False>\nstruct SM100_MMA_F16BF16_TS_NOELECT\n{\n  static_assert(M == 64 || M == 128, \"SM100_MMA_F16BF16_TS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA.\");\n  static_assert((M == 64  && (N % 8 == 0)  && (8 <= N)  && (N <= 256)) ||\n                (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)),\n                \"SM100_MMA_F16BF16_TS_NOELECT N-mode size should be a multiple of 8 between 8 and 256 for M=64,\\\n    
             or a multiple of 16 between 16 and 256 for M=128.\");\n  static_assert(a_major == UMMA::Major::K, \"SM100_MMA_F16BF16_TS_NOELECT A from TMEM can't be transposed\");\n\n  using DRegisters = void;\n  using ARegisters = uint32_t[1];\n  using BRegisters = uint64_t[1];\n  using CRegisters = uint32_t[1];\n\n  CUTE_HOST_DEVICE static void\n  fma(uint32_t const& tmem_a,\n      uint64_t const& desc_b,\n      uint32_t const& tmem_c,\n      uint32_t const& scaleC,\n      uint64_t const& idescE)\n  {\n    uint32_t mask[4] = {0, 0, 0, 0};\n    asm volatile(\n      \"{\\n\\t\"\n      \".reg .pred p;\\n\\t\"\n      \"setp.ne.b32 p, %4, 0;\\n\\t\"\n      \"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \\n\\t\"\n      \"}\\n\"\n      :\n      : \"r\"(tmem_c), \"r\"(tmem_a), \"l\"(desc_b), \"r\"(uint32_t(idescE>>32)), \"r\"(scaleC),\n        \"r\"(mask[0]), \"r\"(mask[1]), \"r\"(mask[2]), \"r\"(mask[3]));\n  }\n};\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,\n          UMMA::Saturate c_sat>\nstruct MMA_Traits<SM100_MMA_F16BF16_TS_NOELECT<a_type, b_type, c_type,\n                                M, N,\n                                a_major, b_major,\n                                a_neg, b_neg, c_sat>>\n{\n  using ValTypeD = c_type;\n  using ValTypeA = a_type;\n  using ValTypeB = b_type;\n  using ValTypeC = c_type;\n  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, \"SM100_MMA_F16BF16_TS_NOELECT supports 16bit types\");\n\n  using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;\n  using FrgTypeB = UMMA::smem_desc<b_major>;\n  using FrgTypeC = UMMA::tmem_frg_1sm<c_type, int32_t, UMMA::TmemAllocMode::NonInterleaved>;\n\n  // Logical shape-K is always 256 bits; transform to units of elements\n  static constexpr int K 
= 256 / cute::sizeof_bits<ValTypeA>::value;\n\n  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;\n  using ThrID   = Layout<_1>;\n  using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n  using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<N>>>>;\n  using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n\n  // Accumulate or overwrite C.   1: read C, 0: ignore C [clear accumulators]\n  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;\n\n  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<\n    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();\n\n  template <class TD, class DLayout,\n            class TA, class ALayout,\n            class TB, class BLayout,\n            class TC, class CLayout>\n  CUTE_HOST_DEVICE constexpr friend\n  void\n  mma_unpack(MMA_Traits          const& traits,\n             Tensor<TD, DLayout>      & D,\n             Tensor<TA, ALayout> const& A,\n             Tensor<TB, BLayout> const& B,\n             Tensor<TC, CLayout> const& C)\n  {\n    static_assert(is_tmem<TD>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_tmem<TA>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_rmem<TB>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_tmem<TC>::value, \"Expected tmem in MMA_Atom::call\");\n\n    uint32_t tmem_a = raw_pointer_cast(A.data());\n    uint64_t desc_b = B[0];\n    uint32_t tmem_c = raw_pointer_cast(D.data());\n    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);\n\n    SM100_MMA_F16BF16_TS_NOELECT<a_type, b_type, c_type,\n                  M, N,\n                  a_major, b_major,\n                  a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);\n  }\n};\n\n\ntemplate <class a_type, class b_type, class 
c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>\nstruct SM100_MMA_F16BF16_SS_NOELECT\n{\n  static_assert(M == 64 || M == 128, \"SM100_MMA_F16BF16_SS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA.\");\n  static_assert((M == 64  && (N % 8 == 0)  && (8 <= N)  && (N <= 256)) ||\n                (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)),\n                \"SM100_MMA_F16BF16_SS_NOELECT N-mode size should be a multiple of 8 between 8 and 256 for M=64,\\\n                 or a multiple of 16 between 16 and 256 for M=128.\");\n\n  using DRegisters = void;\n  using ARegisters = uint64_t[1];\n  using BRegisters = uint64_t[1];\n  using CRegisters = uint32_t[1];\n\n  CUTE_HOST_DEVICE static void\n  fma(uint64_t const& desc_a,\n      uint64_t const& desc_b,\n      uint32_t const& tmem_c,\n      uint32_t const& scaleC,\n      uint64_t const& idescE)\n  {\n    uint32_t mask[4] = {0, 0, 0, 0};\n    asm volatile(\n      \"{\\n\\t\"\n      \".reg .pred p;\\n\\t\"\n      \"setp.ne.b32 p, %4, 0;\\n\\t\"\n      \"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \\n\\t\"\n      \"}\\n\"\n      :\n      : \"r\"(tmem_c), \"l\"(desc_a), \"l\"(desc_b), \"r\"(uint32_t(idescE>>32)), \"r\"(scaleC),\n        \"r\"(mask[0]), \"r\"(mask[1]), \"r\"(mask[2]), \"r\"(mask[3]));\n  }\n};\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>\nstruct MMA_Traits<SM100_MMA_F16BF16_SS_NOELECT<a_type, b_type, c_type,\n                                M, N, a_major, b_major,\n                                a_neg, b_neg>>\n{\n  using ValTypeD = c_type;\n  using ValTypeA = a_type;\n  using ValTypeB = b_type;\n  using ValTypeC = c_type;\n\n  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && 
cute::sizeof_bits_v<b_type> == 16, \"SM100_MMA_F16BF16_SS_NOELECT supports 16bit types\");\n\n  using FrgTypeA = UMMA::smem_desc<a_major>;\n  using FrgTypeB = UMMA::smem_desc<b_major>;\n  using FrgTypeC = UMMA::tmem_frg_1sm<c_type>;\n\n  // Logical shape-K is always 256 bits; transform to units of elements\n  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;\n\n  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;\n  using ThrID   = Layout<_1>;\n  using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n  using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,\n                         Stride<_0,Stride<    _1,Int<N>>>>;\n  using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,\n                         Stride<_0,Stride<    _1,Int<M>>>>;\n\n  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<\n    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();\n\n  // Accumulate or overwrite C.   1: read C, 0: ignore C [clear accumulators]\n  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;\n\n  template <class TD, class DLayout,\n            class TA, class ALayout,\n            class TB, class BLayout,\n            class TC, class CLayout>\n  CUTE_HOST_DEVICE constexpr friend\n  void\n  mma_unpack(MMA_Traits          const& traits,\n             Tensor<TD, DLayout>      & D,\n             Tensor<TA, ALayout> const& A,\n             Tensor<TB, BLayout> const& B,\n             Tensor<TC, CLayout> const& C)\n  {\n    static_assert(is_tmem<TD>::value, \"Expected tmem in MMA_Atom::call\");\n    static_assert(is_rmem<TA>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_rmem<TB>::value, \"Expected desc registers in MMA_Atom::call\");\n    static_assert(is_tmem<TC>::value, \"Expected tmem in MMA_Atom::call\");\n\n    uint64_t desc_a = A[0];\n    uint64_t desc_b = B[0];\n    uint32_t tmem_c = raw_pointer_cast(D.data());\n    uint64_t idesc = 
UMMA::make_runtime_instr_desc<>(traits.idesc_);\n\n    SM100_MMA_F16BF16_SS_NOELECT<a_type, b_type, c_type,\n                  M, N, a_major, b_major,\n                  a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);\n  }\n};\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm100/helpers.cuh",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// Perform SS UTCMMA\n// sA and sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tC_frag should be tmem fragment\ntemplate<\n    typename TiledMMA,\n    typename TensorA,\n    typename TensorB,\n    typename TensorFragC\n>\nCUTE_DEVICE\nvoid utcmma_ss(\n    TiledMMA &tiled_mma,\n    TensorA sA,\n    TensorB sB,\n    TensorFragC tC_frag,\n    bool clear_accum\n) {\n    using namespace cute;\n    tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;\n    ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter\n    auto sA_frag = thr_mma.partition_fragment_A(sA);\n    auto sB_frag = thr_mma.partition_fragment_B(sB);\n    static_assert(size<2>(sA_frag) == size<2>(sB_frag));\n    static_assert(size<1>(sA_frag) == size<1>(tC_frag));\n    static_assert(size<1>(sB_frag) == size<2>(tC_frag));\n    CUTE_UNROLL\n    for (int k = 0; k < size<2>(sA_frag); ++k) {\n        cute::gemm(\n            tiled_mma,\n            sA_frag(_, _, k),\n            sB_frag(_, _, k),\n            tC_frag\n        );\n        tiled_mma.accumulate_ = UMMA::ScaleOut::One;\n    }\n}\n\n// Perform TS UTCMMA\n// sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tA_frag and tC_frag should be tmem fragment\ntemplate<\n    typename TiledMMA,\n    typename TensorA,\n    typename TensorB,\n    typename TensorFragC\n>\nCUTE_DEVICE\nvoid utcmma_ts(\n    TiledMMA &tiled_mma,\n    TensorA tA_frag,\n    TensorB sB,\n    TensorFragC tC_frag,\n    bool clear_accum\n) {\n    using namespace cute;\n    tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;\n    ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter\n    auto sB_frag = thr_mma.partition_fragment_B(sB);\n    static_assert(size<2>(tA_frag) == size<2>(sB_frag));\n    CUTE_UNROLL\n    for (int k = 0; k < size<2>(tA_frag); ++k) {\n        cute::gemm(\n            tiled_mma,\n            tA_frag(_, _, k),\n            sB_frag(_, _, k),\n            tC_frag\n        );\n        tiled_mma.accumulate_ = UMMA::ScaleOut::One;\n    }\n}\n\ntemplate<int MN, int K, int SWIZZLE, typename T = bf16>\nstatic constexpr auto make_umma_canonical_k_major_layout() {\n    using namespace cute;\n    using base_atom_type = \\\n        std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16, \n            UMMA::Layout_K_INTER_Atom<T>,\n            std::conditional_t<SWIZZLE == 32,\n                UMMA::Layout_K_SW32_Atom<T>,\n                std::conditional_t<SWIZZLE == 64,\n                    UMMA::Layout_K_SW64_Atom<T>,\n                    std::conditional_t<SWIZZLE == 128,\n                        UMMA::Layout_K_SW128_Atom<T>,\n                        void\n                    >\n                >\n            >\n        >;\n    static_assert(!std::is_same_v<base_atom_type, void>, \"Invalid SWIZZLE value\");\n    return coalesce(tile_to_shape(\n        base_atom_type{},\n        Shape<Int<MN>, Int<K>>{},\n        Step<_1, _2>{}\n    ), Shape<_1, _1>{});\n}\n\ntemplate<int MN, int K, int SWIZZLE, typename T = bf16>\nstatic constexpr auto make_umma_canonical_mn_major_layout() {\n    using namespace cute;\n    using base_atom_type = \\\n        std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16, \n            UMMA::Layout_MN_INTER_Atom<T>,\n            std::conditional_t<SWIZZLE == 32,\n                UMMA::Layout_MN_SW32_Atom<T>,\n                std::conditional_t<SWIZZLE == 64,\n                    UMMA::Layout_MN_SW64_Atom<T>,\n                    std::conditional_t<SWIZZLE 
== 128,\n                        UMMA::Layout_MN_SW128_Atom<T>,\n                        void\n                    >\n                >\n            >\n        >;\n    static_assert(!std::is_same_v<base_atom_type, void>, \"Invalid SWIZZLE value\");\n    return coalesce(tile_to_shape(\n        base_atom_type{},\n        Shape<Int<MN>, Int<K>>{},\n        Step<_2, _1>{}\n    ), Shape<_1, _1>{});\n}\n\ntemplate<cute::UMMA::Major MAJOR, int MN, int K, int SWIZZLE, typename T = bf16>\nauto make_umma_canonical_layout() {\n    if constexpr (MAJOR == cute::UMMA::Major::K) {\n        return make_umma_canonical_k_major_layout<MN, K, SWIZZLE, T>();\n    } else {\n        return make_umma_canonical_mn_major_layout<MN, K, SWIZZLE, T>();\n    }\n}\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh",
    "content": "#pragma once\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// tma gather4 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)\n// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios\nCUTE_DEVICE\nvoid tma_gather4(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {\n    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);\n    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);\n    asm volatile(\n        \"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\\n\"\n        :\n        : \"r\"(smem_addr), \"l\"(desc_ptr), \"r\"(col_idx), \n          \"r\"(row_idxs.x), \"r\"(row_idxs.y), \"r\"(row_idxs.z), \"r\"(row_idxs.w), \n          \"r\"(mbar_addr), \"l\"(cache_hint)\n        : \"memory\"\n    );\n}\n\n// tma gather4 prefetch (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)\n// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios\nCUTE_DEVICE\nvoid tma_gather4_prefetch(const void* desc_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {\n    asm volatile(\n        \"cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint [%0, {%1, %2, %3, %4, %5}], %6;\\n\"\n        :\n        : \"l\"(desc_ptr), \"r\"(col_idx), \n          \"r\"(row_idxs.x), \"r\"(row_idxs.y), \"r\"(row_idxs.z), \"r\"(row_idxs.w), \n          \"l\"(cache_hint)\n    );\n}\n\n// tma gather4 with cta_group::2, allowing for synchronization across CTAs within a pair of CTAs 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)\ntemplate<bool USE_CTA0_MBAR = false>\nCUTE_DEVICE void tma_gather4_cta_group_2(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {\n    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);\n    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);\n    if constexpr (USE_CTA0_MBAR) {\n        mbar_addr &= cute::Sm100MmaPeerBitMask;\n    }\n    asm volatile(\n        \"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\\n\"\n        :\n        : \"r\"(smem_addr), \"l\"(desc_ptr), \"r\"(col_idx), \n          \"r\"(row_idxs.x), \"r\"(row_idxs.y), \"r\"(row_idxs.z), \"r\"(row_idxs.w), \n          \"r\"(mbar_addr), \"l\"(cache_hint)\n        : \"memory\"\n    );\n}\n\n// Vectorized addition for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add)\nCUTE_DEVICE\nfloat2 float2_add(const float2 &a, const float2 &b) {\n    float2 c;\n    asm volatile(\n        \"add.f32x2 %0, %1, %2;\\n\"\n        : \"=l\"(reinterpret_cast<uint64_t&>(c))\n        : \"l\"(reinterpret_cast<uint64_t const&>(a)),\n          \"l\"(reinterpret_cast<uint64_t const&>(b))\n    );\n    return c;\n}\n\n// Vectorized multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-mul)\nCUTE_DEVICE\nfloat2 float2_mul(const float2 &a, const float2 &b) {\n    float2 c;\n    asm volatile(\n        \"mul.f32x2 %0, %1, %2;\\n\"\n        : \"=l\"(reinterpret_cast<uint64_t&>(c))\n        : \"l\"(reinterpret_cast<uint64_t const&>(a)),\n          \"l\"(reinterpret_cast<uint64_t const&>(b)));\n    return c;\n}\n\n// Vectorized fused multiply-add for float32 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma)\nCUTE_DEVICE\nfloat2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {\n    // return a*b+c\n    float2 d;\n    asm volatile(\n        \"fma.rn.f32x2 %0, %1, %2, %3;\\n\"\n        : \"=l\"(reinterpret_cast<uint64_t&>(d))\n        : \"l\"(reinterpret_cast<uint64_t const&>(a)),\n          \"l\"(reinterpret_cast<uint64_t const&>(b)),\n          \"l\"(reinterpret_cast<uint64_t const&>(c)));\n    return d;\n}\n\n// Vectorized negation for float32\nCUTE_DEVICE\nfloat2 float2_neg(const float2 &a) {\n    float2 t = {-1.0f, -1.0f};\n    return float2_mul(a, t);\n}\n\n// st.bulk (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk)\nCUTE_DEVICE\nvoid st_bulk(void* dst_ptr, int64_t size) {\n    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);\n    asm volatile (\n        \"st.bulk.weak.shared::cta [%0], %1, 0;\\n\"\n        :\n        : \"r\"(dst_addr), \"l\"(size)\n        : \"memory\"\n    );\n}\n\nstruct CUTE_ALIGNAS(16) CLCResponseObj {\n    // An opaque 16B value\n    char opaque[16];\n};\n\nstruct CLCResult {\n    int is_valid;\n    int x, y, z;\n};\n\n// Issue a CLC try_cancel query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)\nCUTE_DEVICE\nvoid issue_clc_query(transac_bar_t &bar, CLCResponseObj &response_obj) {\n    uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);\n    uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);\n    asm volatile(\n        \"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [%0], [%1];\\n\"\n        :\n        : \"r\"(response_addr), \"r\"(mbarrier_addr)\n    );\n}\n\n// Issue a CLC try_cancel query with .multicast::cluster::all 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)\nCUTE_DEVICE\nvoid issue_clc_query_multicast_cluster_all(transac_bar_t &bar, CLCResponseObj &response_obj) {\n    uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);\n    uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);\n    asm volatile(\n        \"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\\n\"\n        :\n        : \"r\"(response_addr), \"r\"(mbarrier_addr)\n    );\n}\n\n// Get the result of a CLC query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel)\n// In this function, we separate get_first_ctaid::x/y/z and hope PTXAS's dead code elimination can remove unnecessary instructions\ntemplate<bool USE_LD_ACQUIRE>\nCUTE_DEVICE\nCLCResult get_clc_query_response(CLCResponseObj &response_obj) {\n    uint32_t response_addr = cute::cast_smem_ptr_to_uint(&response_obj);\n    CLCResult result;\n    #define EMIT_ASM(LD_MODIFIER)                                                                   \\\n        asm volatile(                                                                               \\\n            \"{\\n\"                                                                                   \\\n            \".reg .pred p1;\\n\\t\"                                                                    \\\n            \".reg .b128 clc_result;\\n\\t\"                                                            \\\n            \"ld\" LD_MODIFIER \".shared.b128 clc_result, [%4];\\n\\t\"                                   \\\n            \"clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\\n\\t\"           \\\n            \"selp.u32 %3, 1, 0, p1;\\n\\t\"                                                    
        \\\n            \"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 %0, clc_result;\\n\\t\" \\\n            \"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::y.b32.b128 %1, clc_result;\\n\\t\" \\\n            \"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::z.b32.b128 %2, clc_result;\\n\\t\" \\\n            \"}\\n\"                                                                                   \\\n            : \"=r\"(result.x), \"=r\"(result.y), \"=r\"(result.z), \"=r\"(result.is_valid)                 \\\n            : \"r\"(response_addr)                                                                    \\\n            : \"memory\"                                                                              \\\n        );\n    if constexpr (USE_LD_ACQUIRE) {\n        EMIT_ASM(\".acquire.cta\");\n    } else {\n        EMIT_ASM(\"\");\n    }\n    return result;\n}\n\n// LDG.256 or LDG.256 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)\n// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function\n// NC_STR should be either \"\" or \".nc\"\n// L1_CACHE_HINT_STR should be either \"evict_first\", \"evict_normal\", \"evict_last\", \"evict_unchanged\", or \"no_allocate\"\n// L2_CACHE_HINT_STR should be either \"evict_first\", \"evict_normal\", or \"evict_last\"\n// L2_PREFETCH_SIZE_STR should be either \"64B\", \"128B\", or \"256B\"\n#define KU_LDG_256(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \\\n    { \\\n        static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, \"`global_addr` must be a pointer\"); \\\n        static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, \"`result` must be a pointer\"); \\\n        uint64_t* result_as_uint64_ptr = 
(uint64_t*)(result); \\\n        asm volatile( \\\n            \"ld.global\" NC_STR \".L1::\" L1_CACHE_HINT_STR \".L2::\" L2_CACHE_HINT_STR \".L2::\" L2_PREFETCH_SIZE_STR \".v4.u64 {%0, %1, %2, %3}, [%4];\\n\" \\\n            : \"=l\"(result_as_uint64_ptr[0]), \"=l\"(result_as_uint64_ptr[1]), \\\n            \"=l\"(result_as_uint64_ptr[2]), \"=l\"(result_as_uint64_ptr[3]) \\\n            : \"l\"(global_addr) \\\n        ); \\\n    }\n\n// STG.256 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st)\n// L1_CACHE_HINT_STR should be either \"evict_first\", \"evict_normal\", \"evict_last\", \"evict_unchanged\", or \"no_allocate\"\n// L2_CACHE_HINT_STR should be either \"evict_first\", \"evict_normal\", or \"evict_last\"\n#define KU_STG_256(global_addr, src, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR) \\\n    { \\\n        static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, \"`global_addr` must be a pointer\"); \\\n        static_assert(std::is_pointer_v<decltype(src)> || std::is_array_v<decltype(src)>, \"`src` must be a pointer\"); \\\n        uint64_t const* src_as_uint64_ptr = (uint64_t const*)(src); \\\n        asm volatile( \\\n            \"st.global.L1::\" L1_CACHE_HINT_STR \".L2::\" L2_CACHE_HINT_STR \".v4.u64 [%0], {%1, %2, %3, %4};\\n\" \\\n            : \\\n            : \"l\"(global_addr), \"l\"(src_as_uint64_ptr[0]), \"l\"(src_as_uint64_ptr[1]), \\\n            \"l\"(src_as_uint64_ptr[2]), \"l\"(src_as_uint64_ptr[3]) \\\n        ); \\\n    }\n\n}\n\nnamespace kerutils {\n\n// tcgen05.commit.cta_group::1 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)\nCUTE_DEVICE\nvoid umma_arrive_noelect(transac_bar_t &bar) {\n    uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);\n    asm volatile(\n        \"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\\n\"\n        :\n        :\"r\"(bar_intptr)\n    
);\n}\n\n// tcgen05.commit.cta_group::1, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)\nCUTE_DEVICE\nvoid umma_arrive_multicast_noelect(transac_bar_t &bar, uint16_t cta_mask) {\n    uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);\n    asm volatile(\n        \"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\\n\"\n        :\n        :\"r\"(bar_intptr), \"h\"(cta_mask)\n    );\n}\n\n// tcgen05.commit.cta_group::2 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)\nCUTE_DEVICE\nvoid umma_arrive_2x1SM_noelect(transac_bar_t &bar) {\n    uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);\n    asm volatile(\n        \"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];\\n\"\n        :\n        :\"r\"(bar_intptr)\n    );\n}\n\n// tcgen05.commit.cta_group::2, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)\nCUTE_DEVICE\nvoid umma_arrive_multicast_2x1SM_noelect(transac_bar_t &bar, uint16_t cta_mask) {\n    uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);\n    asm volatile(\n        \"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\\n\"\n        :\n        :\"r\"(bar_intptr), \"h\"(cta_mask)\n    );\n}\n\n// tcgen05.fence::before_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)\n__device__ __forceinline__ void tcgen05_before_thread_sync() {\n    asm volatile(\"tcgen05.fence::before_thread_sync;\");\n}\n\n// tcgen05.fence::after_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)\n__device__ __forceinline__ void tcgen05_after_thread_sync() {\n    asm volatile(\"tcgen05.fence::after_thread_sync;\");\n}\n\n\n// Load from tensor memory, 32 data path lanes, 
32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)\ntemplate <int kNumElements>\n__device__ __forceinline__\nvoid tmem_ld_32dp32bNx(uint32_t tmem_start, void* data_) {\n    uint32_t* data = (uint32_t*)data_;\n    static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, \"Invalid kNumElements\");\n    // NOTE The following code crashes VSCode intellisense engine, so we disable it\n#ifndef __VSCODE_IDE__\n    [&]<size_t... Is>(cute::index_sequence<Is...>) {\n        if constexpr (kNumElements == 1) {\n            cute::SM100_TMEM_LOAD_32dp32b1x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumElements == 2) {\n            cute::SM100_TMEM_LOAD_32dp32b2x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumElements == 4) {\n            cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumElements == 8) {\n            cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumElements == 16) {\n            cute::SM100_TMEM_LOAD_32dp32b16x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumElements == 32) {\n            cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumElements == 64) {\n            cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumElements == 128) {\n            cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, data[Is]...);\n        }\n    }(cute::make_index_sequence<kNumElements>{});\n#endif\n}\n\n// Load from tensor memory, 16 data path lanes, 128-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)\ntemplate <int kNumReplications>\n__device__ __forceinline__\nvoid tmem_ld_16dp128bNx(uint32_t tmem_start, void* data_) {\n    uint32_t* data = (uint32_t*)data_;\n    static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32 || kNumReplications == 64, \"Invalid kNumReplications\");\n#ifndef __VSCODE_IDE__\n    [&]<size_t... Is>(cute::index_sequence<Is...>) {\n        if constexpr (kNumReplications == 1) {\n            cute::SM100_TMEM_LOAD_16dp128b1x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 2) {\n            cute::SM100_TMEM_LOAD_16dp128b2x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 4) {\n            cute::SM100_TMEM_LOAD_16dp128b4x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 8) {\n            cute::SM100_TMEM_LOAD_16dp128b8x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 16) {\n            cute::SM100_TMEM_LOAD_16dp128b16x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 32) {\n            cute::SM100_TMEM_LOAD_16dp128b32x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 64) {\n            cute::SM100_TMEM_LOAD_16dp128b64x::copy(tmem_start, data[Is]...);\n        }\n    }(cute::make_index_sequence<kNumReplications*2>{});\n#endif\n}\n\n// Load from tensor memory, 16 data path lanes, 256-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)\ntemplate <int kNumReplications>\n__device__ __forceinline__\nvoid tmem_ld_16dp256bNx(uint32_t tmem_start, void* data_) {\n    uint32_t* data = (uint32_t*)data_;\n    static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32, \"Invalid kNumReplications\");\n#ifndef __VSCODE_IDE__\n    [&]<size_t... Is>(cute::index_sequence<Is...>) {\n        if constexpr (kNumReplications == 1) {\n            cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 2) {\n            cute::SM100_TMEM_LOAD_16dp256b2x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 4) {\n            cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 8) {\n            cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 16) {\n            cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start, data[Is]...);\n        } else if constexpr (kNumReplications == 32) {\n            cute::SM100_TMEM_LOAD_16dp256b32x::copy(tmem_start, data[Is]...);\n        }\n    }(cute::make_index_sequence<kNumReplications*4>{});\n#endif\n}\n\n// Store into tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)\ntemplate <int kNumElements>\n__device__ __forceinline__\nvoid tmem_st_32dp32bNx(uint32_t tmem_start, void const* data_) {\n    uint32_t const* data = (uint32_t const*)data_;\n    static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, \"Invalid kNumElements\");\n#ifndef __VSCODE_IDE__\n    [&]<size_t... Is>(cute::index_sequence<Is...>) {\n        if constexpr (kNumElements == 1) {\n            cute::SM100_TMEM_STORE_32dp32b1x::copy(data[Is]..., tmem_start);\n        } else if constexpr (kNumElements == 2) {\n            cute::SM100_TMEM_STORE_32dp32b2x::copy(data[Is]..., tmem_start);\n        } else if constexpr (kNumElements == 4) {\n            cute::SM100_TMEM_STORE_32dp32b4x::copy(data[Is]..., tmem_start);\n        } else if constexpr (kNumElements == 8) {\n            cute::SM100_TMEM_STORE_32dp32b8x::copy(data[Is]..., tmem_start);\n        } else if constexpr (kNumElements == 16) {\n            cute::SM100_TMEM_STORE_32dp32b16x::copy(data[Is]..., tmem_start);\n        } else if constexpr (kNumElements == 32) {\n            cute::SM100_TMEM_STORE_32dp32b32x::copy(data[Is]..., tmem_start);\n        } else if constexpr (kNumElements == 64) {\n            cute::SM100_TMEM_STORE_32dp32b64x::copy(data[Is]..., tmem_start);\n        } else if constexpr (kNumElements == 128) {\n            cute::SM100_TMEM_STORE_32dp32b128x::copy(data[Is]..., tmem_start);\n        }\n    }(cute::make_index_sequence<kNumElements>{});\n#endif\n}\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include <kerutils/device/common.h>\n\nnamespace cute {\n\n// Extensions to CuTe\n// CuTe's built-in SM100_TMA_2SM_LOAD_1D series requires the number of participating threads to be 2 (using ThrID = Layout<_2>;) and also splits the data, which is really annoying to use, so we modified our own version. Additionally, to keep it consistent with other parts that use SM90 TMA, we made it accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100.\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory\n////////////////////////////////////////////////////////////////////////////////////////////////////\nstruct SM100_TMA_2SM_LOAD_1D_NOSPLIT\n{\n  CUTE_HOST_DEVICE static void\n  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,\n       [[maybe_unused]] void      * smem_ptr,\n       [[maybe_unused]] int32_t const& crd0)\n  {\n#if defined(CUTE_ARCH_TMA_SM100_ENABLED)\n    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);\n    // Executed by both CTAs. 
Set peer bit to 0 so that the\n    // transaction bytes will update CTA0's barrier.\n    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);\n    asm volatile (\n      \"cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint\"\n      \" [%0], [%1, {%3}], [%2], %4;\"\n      :\n      : \"r\"(smem_int_ptr), \"l\"(gmem_int_desc), \"r\"(smem_int_mbar),\n        \"r\"(crd0), \"l\"(cache_hint)\n      : \"memory\");\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.\");\n#endif\n  }\n};\nstruct SM100_TMA_2SM_LOAD_2D_NOSPLIT\n{\n  CUTE_HOST_DEVICE static void\n  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,\n       [[maybe_unused]] void      * smem_ptr,\n       [[maybe_unused]] int32_t const& crd0, int32_t const& crd1)\n  {\n#if defined(CUTE_ARCH_TMA_SM100_ENABLED)\n    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);\n    // Executed by both CTAs. 
Set peer bit to 0 so that the\n    // transaction bytes will update CTA0's barrier.\n    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);\n    asm volatile (\n      \"cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint\"\n      \" [%0], [%1, {%3, %4}], [%2], %5;\"\n      :\n      : \"r\"(smem_int_ptr), \"l\"(gmem_int_desc), \"r\"(smem_int_mbar),\n        \"r\"(crd0), \"r\"(crd1), \"l\"(cache_hint)\n      : \"memory\");\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.\");\n#endif\n  }\n};\nstruct SM100_TMA_2SM_LOAD_3D_NOSPLIT\n{\n  CUTE_HOST_DEVICE static void\n  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,\n       [[maybe_unused]] void      * smem_ptr,\n       [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)\n  {\n#if defined(CUTE_ARCH_TMA_SM100_ENABLED)\n    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);\n    // Executed by both CTAs. 
Set peer bit to 0 so that the\n    // transaction bytes will update CTA0's barrier.\n    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);\n    asm volatile (\n      \"cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint\"\n      \" [%0], [%1, {%3, %4, %5}], [%2], %6;\"\n      :\n      : \"r\"(smem_int_ptr), \"l\"(gmem_int_desc), \"r\"(smem_int_mbar),\n        \"r\"(crd0), \"r\"(crd1), \"r\"(crd2), \"l\"(cache_hint)\n      : \"memory\");\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.\");\n#endif\n  }\n};\nstruct SM100_TMA_2SM_LOAD_4D_NOSPLIT\n{\n  CUTE_HOST_DEVICE static void\n  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,\n       void      * smem_ptr,\n       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)\n  {\n#if defined(CUTE_ARCH_TMA_SM100_ENABLED)\n    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);\n    // Executed by both CTAs. 
Set peer bit to 0 so that the\n    // transaction bytes will update CTA0's barrier.\n    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);\n    asm volatile (\n      \"cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint\"\n      \" [%0], [%1, {%3, %4, %5, %6}], [%2], %7;\"\n      :\n      : \"r\"(smem_int_ptr), \"l\"(gmem_int_desc), \"r\"(smem_int_mbar),\n        \"r\"(crd0), \"r\"(crd1), \"r\"(crd2), \"r\"(crd3), \"l\"(cache_hint)\n      : \"memory\");\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.\");\n#endif\n  }\n};\nstruct SM100_TMA_2SM_LOAD_5D_NOSPLIT\n{\n  CUTE_HOST_DEVICE static void\n  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,\n       void      * smem_ptr,\n       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)\n  {\n#if defined(CUTE_ARCH_TMA_SM100_ENABLED)\n    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);\n    // Executed by both CTAs. 
Set peer bit to 0 so that the\n    // transaction bytes will update CTA0's barrier.\n    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);\n    asm volatile (\n      \"cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint\"\n      \" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;\"\n      :\n      : \"r\"(smem_int_ptr), \"l\"(gmem_int_desc), \"r\"(smem_int_mbar),\n        \"r\"(crd0), \"r\"(crd1), \"r\"(crd2), \"r\"(crd3), \"r\"(crd4), \"l\"(cache_hint)\n      : \"memory\");\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.\");\n#endif\n  }\n};\nstruct SM100_TMA_2SM_LOAD_NOSPLIT\n{\n  CUTE_HOST_DEVICE static void\n  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,\n       void      * smem_ptr,\n       int32_t const& crd0)\n  {\n    return SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0);\n  }\n  CUTE_HOST_DEVICE static void\n  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,\n       void      * smem_ptr,\n       int32_t const& crd0, int32_t const& crd1)\n  {\n    return SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1);\n  }\n  CUTE_HOST_DEVICE static void\n  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,\n       void      * smem_ptr,\n       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)\n  {\n    return SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2);\n  }\n  CUTE_HOST_DEVICE static void\n  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,\n       void      * smem_ptr,\n       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)\n  {\n    return SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, 
crd3);\n  }\n  CUTE_HOST_DEVICE static void\n  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,\n       void      * smem_ptr,\n       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)\n  {\n    return SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);\n  }\n  using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;\n};\nstruct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {};\n// The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar\n// Use .with(tma_mbar) to construct an executable version\ntemplate <class NumBitsPerTMA, class AuxParams_>\nstruct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>\n{\n  using ThrID     = Layout<_1>;\n  // Map from (src-thr,src-val) to bit\n  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;\n  // Map from (dst-thr,dst-val) to bit\n  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;\n  // Reference map from (thr,val) to bit\n  using RefLayout = SrcLayout;\n  // SM100_TMA_2SM_LOAD_NOSPLIT arguments\n  TmaDescriptor tma_desc_;\n  using AuxParams = AuxParams_;\n  AuxParams aux_params_;\n  // Return TmaDescriptor/TensorMap\n  CUTE_HOST_DEVICE constexpr\n  TmaDescriptor const*\n  get_tma_descriptor() const {\n    return &tma_desc_;\n  }\n  // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar\n  CUTE_HOST_DEVICE constexpr\n  Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>\n  with(\n    uint64_t& tma_mbar,\n    [[maybe_unused]] uint16_t const& multicast_mask = 0,\n    TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {\n    // We accept multicast_mask here to keep the API for both atoms consistent\n    return {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)};\n  }\n  // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm)\n  CUTE_HOST_DEVICE constexpr\n  Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>\n  with(\n    TmaDescriptor const* new_tma_desc,\n    uint64_t& tma_mbar,\n    [[maybe_unused]] uint16_t const& multicast_mask = 0,\n    TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {\n    // We accept multicast_mask here to keep the API for both atoms consistent\n    return {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)};\n  }\n  // Generate the TMA coord tensor\n  template <class GShape>\n  CUTE_HOST_DEVICE constexpr\n  auto\n  get_tma_tensor(GShape const& g_shape) const {\n    static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);\n    return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_));\n  }\n  // Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with()\n  template <class TS, class SLayout,\n            class TD, class DLayout>\n  CUTE_HOST_DEVICE friend constexpr void\n  copy_unpack(Copy_Traits        const& traits,\n              Tensor<TS,SLayout> const& src,\n              Tensor<TD,DLayout>      & dst) = delete;\n};\n// The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar\ntemplate <class NumBitsPerTMA>\nstruct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>\n  : TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>\n{\n  using ThrID     = Layout<_1>;\n  // Map from (src-thr,src-val) to bit\n  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;\n  // Map from (dst-thr,dst-val) to bit\n  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;\n  // Reference map from (thr,val) to bit\n  using RefLayout = SrcLayout;\n  // SM100_TMA_2SM_LOAD_NOSPLIT arguments\n  tuple<\n  TmaDescriptor const*,\n  uint64_t*, // smem mbarrier\n  uint64_t   // cache hint\n  > const opargs_;\n  CUTE_HOST_DEVICE\n  Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache)\n    : 
opargs_(desc, mbar, cache) {}\n};\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm80/helpers.cuh",
    "content": "#pragma once\n\n#include \"kerutils/device/common.h\"\n#include \"kerutils/device/sm80/intrinsics.cuh\"\n\nnamespace kerutils {\n\n// Retrieve the value of `%smid` and check its range\nCUTE_DEVICE\nuint32_t get_sm_id_with_range_check(uint32_t num_physical_sms) {\n    uint32_t sm_id = get_sm_id();\n    if (!(sm_id < num_physical_sms)) {\n        trap();\n    }\n    return sm_id;\n}\n\n#ifndef KU_TRAP_ONLY_DEVICE_ASSERT\n#define KU_TRAP_ONLY_DEVICE_ASSERT(cond) \\\ndo { \\\n    if (not (cond)) \\\n        asm(\"trap;\"); \\\n} while (0)\n#endif\n\n// Construct a `float2` from a single `float` by duplicating the value \nCUTE_DEVICE\nfloat2 float2float2(const float &x) {\n    return float2 {x, x};\n}\n\nCUTE_DEVICE\nvoid st_shared(void* ptr, __int128_t val) {\n    asm volatile(\"st.shared.b128 [%0], %1;\" :: \"l\"(__cvta_generic_to_shared(ptr)), \"q\"(val));\n}\n\nCUTE_DEVICE\nvoid st_shared(void* ptr, float4 val) {\n    st_shared(ptr, *(__int128_t*)&val);\n}\n\nCUTE_DEVICE\n__int128_t ld_shared(void* ptr) {\n    __int128_t val;\n    asm volatile(\"ld.shared.b128 %0, [%1];\" : \"=q\"(val) : \"l\"(__cvta_generic_to_shared(ptr)));\n    return val;\n}\n\nCUTE_DEVICE\nfloat4 ld_shared_float4(void* ptr) {\n    __int128_t temp = ld_shared(ptr);\n    return *(float4*)&temp;\n}\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh",
    "content": "#pragma once\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// cp.async.cg (cache global) with prefetch and predicate (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async)\ntemplate<PrefetchSize PREFETCH_SIZE=PrefetchSize::B128>\nCUTE_DEVICE\nvoid cp_async_cacheglobal(const void* src, void* dst, bool pred=true) {\n    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);\n    if constexpr (PREFETCH_SIZE == PrefetchSize::B64) {\n        asm volatile(\"cp.async.cg.shared.global.L2::64B [%0], [%1], 16, %2;\\n\"\n            :: \"r\"(dst_addr),\n               \"l\"(src),\n               \"r\"(pred?16:0));\n    } else if constexpr (PREFETCH_SIZE == PrefetchSize::B128) {\n        asm volatile(\"cp.async.cg.shared.global.L2::128B [%0], [%1], 16, %2;\\n\"\n            :: \"r\"(dst_addr),\n               \"l\"(src),\n               \"r\"(pred?16:0));\n    } else if constexpr (PREFETCH_SIZE == PrefetchSize::B256) {\n        asm volatile(\"cp.async.cg.shared.global.L2::256B [%0], [%1], 16, %2;\\n\"\n            :: \"r\"(dst_addr),\n               \"l\"(src),\n               \"r\"(pred?16:0));\n    } else {\n        static_assert(PREFETCH_SIZE == PrefetchSize::B64 ||\n                      PREFETCH_SIZE == PrefetchSize::B128 ||\n                      PREFETCH_SIZE == PrefetchSize::B256,\n                      \"Unsupported prefetch size for cp_async_cacheglobal.\");\n    }\n}\n\n// Create fraction-based cache policy (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-createpolicy)\ntemplate<CacheHint PRIMARY_PRIORITY, CacheHint SECONDARY_PRIORITY>\nCUTE_DEVICE\nint64_t create_fraction_based_cache_policy(float fraction = 1.0f) {\n    int64_t result;\n    #define EMIT(PRIMARY_PRIORITY_STR, SECONDARY_PRIORITY_STR) \\\n        asm volatile( \\\n            \"createpolicy.fractional.L2::\" PRIMARY_PRIORITY_STR \".L2::\" SECONDARY_PRIORITY_STR 
\".b64 %0, %1;\\n\" \\\n            : \"=l\"(result) \\\n            : \"f\"(fraction) \\\n        );\n    #define EMIT2(PRIMARY_PRIORITY_STR) \\\n        { \\\n            if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_FIRST) { \\\n                EMIT(PRIMARY_PRIORITY_STR, \"evict_first\") \\\n            } else if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED) { \\\n                EMIT(PRIMARY_PRIORITY_STR, \"evict_unchanged\") \\\n            } else { \\\n                static_assert(SECONDARY_PRIORITY == CacheHint::EVICT_FIRST || \\\n                            SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED, \\\n                            \"Unsupported secondary cache hint for create_fraction_based_cache_policy.\"); \\\n            } \\\n        }\n    if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_FIRST) {\n        EMIT2(\"evict_first\");\n    } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL) {\n        EMIT2(\"evict_normal\");\n    } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_LAST) {\n        EMIT2(\"evict_last\");\n    } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED) {\n        EMIT2(\"evict_unchanged\");\n    } else {\n        static_assert(PRIMARY_PRIORITY == CacheHint::EVICT_FIRST ||\n                      PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL ||\n                      PRIMARY_PRIORITY == CacheHint::EVICT_LAST ||\n                      PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED,\n                      \"Unsupported primary cache hint for create_fraction_based_cache_policy.\");\n    }\n    #undef EMIT\n    #undef EMIT2\n    return result;\n}\n\n// Create a simple cache policy (equivalent to create_fraction_based_cache_policy(1.0f))\n// The same as cute::TMA::CacheHintSmXX\ntemplate<CacheHint CACHE_HINT>\nCUTE_DEVICE\nconstexpr int64_t create_simple_cache_policy() {\n    if constexpr (CACHE_HINT == CacheHint::EVICT_FIRST) {\n        return 0x12F0000000000000;  // Result 
of createpolicy.fractional.L2::evict_first.b64\n    } else if constexpr (CACHE_HINT == CacheHint::EVICT_NORMAL) {\n        return 0x1000000000000000;  // Copied from CuTe. Unsure about the exact meaning. (TODO Change to 0x16F0000000000000?)\n    } else if constexpr (CACHE_HINT == CacheHint::EVICT_LAST) {\n        return 0x14F0000000000000;  // Result of createpolicy.fractional.L2::evict_last.b64\n    } else {\n        static_assert(CACHE_HINT == CacheHint::EVICT_FIRST ||\n                      CACHE_HINT == CacheHint::EVICT_NORMAL ||\n                      CACHE_HINT == CacheHint::EVICT_LAST,\n                      \"Unsupported cache hint for create_simple_cache_policy.\");\n    }\n}\n\n// AtomicAdd (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)\nCUTE_DEVICE\nvoid atomicadd_f32_with_policy_and_pred(void* global_addr, const float &data, int64_t cache_policy, uint32_t pred = true) {\n    asm volatile(\n        \"{\\n\\t\"\n        \".reg .pred p;\\n\\t\"\n        \"setp.eq.u32 p, %3, 1;\\n\\t\"\n        \"@p red.relaxed.gpu.global.add.L2::cache_hint.f32 [%1], %0, %2; \\n\\t\"\n        \"}\"\n        : \n        : \"f\"(data),\n          \"l\"((int64_t)global_addr), \"l\"(cache_policy), \"r\"(pred)\n    );\n}\n\n// Get the id of the current SM\n// About %smid (https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid): The PTX documentation says that %smid ranges from 0 to %nsmid-1, while \"The SM identifier numbering is not guaranteed to be contiguous, so %nsmid may be larger than the physical number of SMs in the device.\". However, results show that, at least for sm90 and sm100f, %nsmid is the number of physical SMs - 1. 
For the sake of safety, I recommend checking the return value of get_sm_id manually or calling `get_sm_id_with_range_check()` defined in `device/sm80/helpers.cuh`.\n// Besides, the PTX documentation also says that this number may change due to preemption, but currently this never happens according to [REDACTED]\nCUTE_DEVICE\nuint32_t get_sm_id() {\n    uint32_t ret;\n    asm volatile(\"mov.u32 %0, %%smid;\\n\" : \"=r\"(ret));\n    return ret;\n}\n\n// trap (https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-trap)\nCUTE_DEVICE\nvoid trap() {\n    asm volatile(\"trap;\\n\");\n}\n\n// LDG.128 or LDG.128 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)\n// We use a macro instead of a function here, since a function would need a multi-level recursive dispatch based on template parameters\n// NC_STR should be either \"\" or \".nc\"\n// L1_CACHE_HINT_STR should be either \"evict_first\", \"evict_normal\", \"evict_last\", \"evict_unchanged\", or \"no_allocate\"\n// L2_PREFETCH_SIZE_STR should be either \"64B\", \"128B\", or \"256B\"\n// L2 cache hint is not supported since it's only supported for LDG.256\n#define KU_LDG_128(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \\\n    { \\\n        static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, \"`global_addr` must be a pointer\"); \\\n        static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, \"`result` must be a pointer\"); \\\n        uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \\\n        asm volatile( \\\n            \"ld.global\" NC_STR \".L1::\" L1_CACHE_HINT_STR \".L2::\" L2_PREFETCH_SIZE_STR \".v2.u64 {%0, %1}, [%2];\\n\" \\\n            : \"=l\"(result_as_uint64_ptr[0]), \"=l\"(result_as_uint64_ptr[1]) \\\n            : \"l\"(global_addr) \\\n        ); \\\n    }\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm90/helpers.cuh",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\ntemplate<\n    typename TMA,\n    typename Tensor0,\n    typename Tensor1\n>\nCUTE_DEVICE\nvoid launch_tma_copy(\n    const TMA &tma_copy,\n    Tensor0 src,\n    Tensor1 dst,\n    transac_bar_t &bar,\n    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL\n) {\n    auto thr_tma = tma_copy.get_slice(cute::_0{});\n    cute::copy(\n        tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint),\n        thr_tma.partition_S(src),\n        thr_tma.partition_D(dst)\n    );\n}\n\n// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx\n// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a\nCUTE_DEVICE\nint get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {\n    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);\n    return row_idx;\n}\n\n// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in some rows. 
This function converts the local_elem_idx to the actual col_idx\n// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a\nCUTE_DEVICE\nint get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {\n    int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);\n    return col_idx;\n}\n\ntemplate <bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\nCUTE_DEVICE\nvoid wgmma(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC, bool zero_init) {\n    using namespace cute;\n    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;\n    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const\n    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n    warpgroup_fence_operand(tCrC);\n    warpgroup_arrive();\n    tiled_mma.accumulate_ = zero_init ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;\n    // Unroll the K mode manually to set scale D to 1\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n        cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n        tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n    }\n    if constexpr (commit) {\n        warpgroup_commit_batch();\n    }\n    warpgroup_fence_operand(tCrC);\n    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n}\n\ntemplate <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\nCUTE_DEVICE\nvoid wgmma_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {\n    using namespace cute;\n    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);\n    Tensor sA_frag = thr_mma.partition_fragment_A(sA);\n    Tensor sB_frag = thr_mma.partition_fragment_B(sB);\n    static_assert(size<2>(sA_frag) == size<2>(sB_frag));\n\n    warpgroup_fence_operand(rC_frag);\n    warpgroup_arrive();\n    tiled_mma.accumulate_ = clear_accum ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k = 0; k < size<2>(sA_frag); ++k) {\n        cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);\n        tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n    }\n    warpgroup_fence_operand(rC_frag);\n}\n\ntemplate <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\nCUTE_DEVICE\nvoid wgmma_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {\n    using namespace cute;\n    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);\n    Tensor sB_frag = thr_mma.partition_fragment_B(sB);\n    static_assert(size<2>(rA_frag) == size<2>(sB_frag));\n\n    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));\n    warpgroup_fence_operand(rC_frag);\n    warpgroup_arrive();\n    tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k = 0; k < size<2>(rA_frag); ++k) {\n        cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);\n        tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n    }\n    warpgroup_fence_operand(rC_frag);\n    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));\n}\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh",
    "content": "#pragma once\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async)\ntemplate<typename T>\nCUTE_DEVICE\nstatic void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) {\n    static_assert(sizeof(T) == 16, \"Data type must be 16 bytes (128 bits) for st_async.\");\n    long2 data_long2 = *reinterpret_cast<const long2*>(&data);\n    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);\n    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);\n    asm volatile (\n        \"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \\n\"\n        :\n        : \"r\"(dst_addr), \"l\"(data_long2.x), \"l\"(data_long2.y), \"r\"(mbar_addr)\n    );\n}\n\nstatic constexpr int PEER_ADDR_MASK = 16777216;\n\n// Given an address in the current CTA, return the corresponding address in the peer CTA\ntemplate<typename T>\nCUTE_DEVICE\nT* get_peer_addr(const T* p) {\n    return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);\n}\n\n// Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0)\ntemplate<typename T>\nCUTE_DEVICE\nT* get_cta0_addr(const T* p) {\n    constexpr int CTA0_ADDR_MASK = 0xFEFFFFFF;\n    return (T*)((int64_t)(p) & CTA0_ADDR_MASK);\n}\n\n// TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk)\nCUTE_DEVICE\nvoid tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {\n    uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(src_ptr);\n    asm volatile(\"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\\n\"\n                     :\n                     : \"l\"(dst_ptr), \"r\"(smem_int_ptr), \"r\"(store_bytes)\n                     : \"memory\");\n}\n\n// Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)\nCUTE_DEVICE\nvoid barrier_cluster_arrive_release() {\n    asm volatile(\"barrier.cluster.arrive.release;\" : : : \"memory\");\n}\n\n// Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)\nCUTE_DEVICE\nvoid barrier_cluster_arrive_relaxed() {\n    asm volatile(\"barrier.cluster.arrive.relaxed;\" : : :);\n}\n\n// Cluster barrier wait with .acquire modifier. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)\nCUTE_DEVICE\nvoid barrier_cluster_wait_acquire() {\n    asm volatile(\"barrier.cluster.wait.acquire;\" : : : \"memory\");\n}\n\n// mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive)\nCUTE_DEVICE\nvoid mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) {\n    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&mbar);\n    asm volatile(\n        \"{\\n\\t\"\n        \"mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\\n\\t\"\n        \"}\"\n        :\n        : \"r\"(smem_addr));\n}\n\n// AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)\nCUTE_DEVICE\nvoid atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) {\n    asm volatile(\n        \"{\\n\\t\"\n        \".reg .pred p;\\n\\t\"\n        \"setp.eq.u32 p, %6, 1;\\n\\t\"\n        \"@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \\n\\t\"\n        \"}\"\n        : \n        : \"f\"(data.x), \"f\"(data.y), \"f\"(data.z), \"f\"(data.w),\n          \"l\"((int64_t)global_addr), \"l\"(cache_policy), \"r\"(pred)\n    );\n}\n\n// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)\nCUTE_DEVICE\nvoid cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) {\n    uint32_t dst_smem_addr = cute::cast_smem_ptr_to_uint(dst_ptr);\n    uint32_t src_smem_addr = cute::cast_smem_ptr_to_uint(src_ptr);\n    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);\n    asm volatile(\n        
\"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \\n\"\n        :\n        : \"r\"(dst_smem_addr), \"r\"(src_smem_addr), \"r\"(load_bytes), \"r\"(mbar_addr)\n    );\n}\n\n}\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/host/host.h",
    "content": "#pragma once\n\n#include <exception>\n#include <string>\n#include <sstream>\n#include <vector>\n\n#include <cuda_runtime_api.h>\n#include <cuda.h>\n\n#include <cutlass/cuda_host_adapter.hpp>\n\n#include \"kerutils/common/common.h\"\n\nnamespace kerutils {\n\nclass KUException final : public std::exception {\n    std::string message = {};\n\npublic:\n    template<typename... Args>\n    explicit KUException(const char *name, const char* file, const int line, Args&&... args) {\n        std::ostringstream oss;\n        \n        oss << name << \" error (\" << file << \":\" << line << \"): \";\n        (oss << ... << args);\n        message = oss.str();\n    }\n\n    const char *what() const noexcept override {\n        return message.c_str();\n    }\n};\n\n#define THROW_KU_EXCEPTION(name, ...) \\\n    throw kerutils::KUException(name, __FILE__, __LINE__, __VA_ARGS__)\n\n#define KU_CUDA_CHECK(call)                                                                                  \\\ndo {                                                                                                  \\\n    cudaError_t status_ = call;                                                                       \\\n    if (status_ != cudaSuccess) {                                                                     \\\n        fprintf(stderr, \"CUDA error (%s:%d): %s\\n\", __FILE__, __LINE__, cudaGetErrorString(status_)); \\\n        THROW_KU_EXCEPTION(\"CUDA\", \"CUDA error: \", cudaGetErrorString(status_));                       \\\n    }                                                                                                 \\\n} while(0)\n\n#define KU_CUTLASS_CHECK(call) \\\ndo {                                                                                                  \\\n    cutlass::Status status_ = call;                                                                   \\\n    if (status_ != cutlass::Status::kSuccess) {                                        
              \\\n        fprintf(stderr, \"CUTLASS error (%s:%d): %d\\n\", __FILE__, __LINE__, static_cast<int>(status_)); \\\n        THROW_KU_EXCEPTION(\"CUTLASS\", \"CUTLASS error: \", static_cast<int>(status_));                 \\\n    }                                                                                                 \\\n} while(0)\n\n// This `KU_ASSERT` is triggered no matter if the code is compiled with `-DNDEBUG` or not.\n#define KU_ASSERT(cond, ...)                                                                                      \\\n    do {                                                                                                  \\\n        if (not (cond)) {                                                                                 \\\n            fprintf(stderr, \"Assertion `%s` failed (%s:%d): \", #cond, __FILE__, __LINE__);          \\\n            if constexpr (sizeof(#__VA_ARGS__) > 1) {                                                \\\n                fprintf(stderr, \", \" __VA_ARGS__);                                                        \\\n            }                                                                                             \\\n            fprintf(stderr, \"\\n\");                                                                       \\\n            THROW_KU_EXCEPTION(\"Assertion\", \"Assertion `\", #cond, \"` failed.\");                          \\\n        }                                                                                                 \\\n    } while(0)\n\n#define KU_CHECK_KERNEL_LAUNCH() KU_CUDA_CHECK(cudaGetLastError())\n\ntemplate<typename T>\ninline __host__ __device__ constexpr T ceil_div(const T &a, const T &b) {\n    return (a + b - 1) / b;\n}\n\ntemplate<typename T>\ninline __host__ __device__ constexpr T ceil(const T &a, const T &b) {\n    return (a + b - 1) / b * b;\n}\n\n// A wrapper for make_tensor_map\nstatic inline CUtensorMap make_tensor_map(\n    const 
std::vector<uint64_t> &size,\n    const std::vector<uint64_t> &strides,   // PAY ATTENTION: In BYTES\n    const std::vector<uint32_t> &box_size,\n    void* global_ptr,\n    CUtensorMapDataType data_type,\n    CUtensorMapSwizzle swizzle_mode,\n    CUtensorMapL2promotion l2_promotion,\n    CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,\n    CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,\n    const std::vector<uint32_t> &element_strides_ = {}\n) {\n    int dim = size.size();\n    KU_ASSERT(dim >= 1);\n    \n    std::vector<uint32_t> element_strides;\n    if (element_strides_.empty()) {\n        for (int i = 0; i < dim; ++i)\n            element_strides.push_back(1);\n    } else {\n        element_strides = element_strides_;\n    }\n    KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim);\n\n    CUtensorMap result;\n    CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(\n        &result,\n        data_type,\n        dim,\n        global_ptr,\n        size.data(),\n        strides.data(),\n        box_size.data(),\n        element_strides.data(),\n        interleave_mode,\n        swizzle_mode,\n        l2_promotion,\n        oob_fill\n    );\n    if (ret_code != CUresult::CUDA_SUCCESS) {\n        auto print_vector = [&](auto t, const char* fmt, const char end='\\n') {\n            for (auto elem : t) {\n                printf(fmt, elem);\n            }\n            printf(\"%c\", end);\n        };\n        fprintf(stderr, \"Failed to create tensormap\\n\");\n        fprintf(stderr, \"Dim: %d\\n\", dim);\n        printf(\"size: \"); print_vector(size, \"%lu \");\n        printf(\"strides: \"); print_vector(strides, \"%lu \");\n        printf(\"box_size: \"); print_vector(box_size, \"%u \");\n        printf(\"element_strides: \"); print_vector(element_strides, \"%u \");\n        
printf(\"global ptr: 0x%lx\\n\", (int64_t)global_ptr);\n        printf(\"data_type: %d\\n\", (int)data_type);\n        printf(\"swizzle_mode: %d\\n\", (int)swizzle_mode);\n        printf(\"l2_promotion: %d\\n\", (int)l2_promotion);\n        printf(\"interleave_mode: %d\\n\", (int)interleave_mode);\n        printf(\"oob_fill: %d\\n\", (int)oob_fill);\n        KU_ASSERT(false);\n    }\n    return result;\n}\n\n// Given strides (in number of elements), this function converts their datatype in uint64_t and then multiplies by elem_size\ntemplate<typename T>\nstatic inline std::vector<uint64_t> make_stride_helper(const std::vector<T> &strides_in_elems, size_t elem_size) {\n    std::vector<uint64_t> res;\n    for (auto stride : strides_in_elems) {\n        res.push_back(((uint64_t)stride) * elem_size);\n    }\n    return res;\n}\n\n}"
  },
  {
    "path": "csrc/kerutils/include/kerutils/kerutils.cuh",
    "content": "#pragma once\n\n#include \"host/host.h\"\n#include \"device/device.cuh\"\n"
  },
  {
    "path": "csrc/kerutils/include/kerutils/supplemental/torch_tensors.h",
    "content": "#pragma once\n\n#include <functional>\n\n#include <torch/python.h>\n\n#include \"kerutils/common/common.h\"\n\nnamespace kerutils {\n\n// Check whether the given tensor or optional tensor satisfies the given condition\n// If tensor_or_opt is a tensor, check_fn is applied directly\n// If tensor_or_opt is an optional tensor, check_fn is applied only when the optional has value\ntemplate<typename T>\nstatic inline bool _check_optional_tensor(const T& tensor_or_opt, const std::function<bool(const at::Tensor&)>& check_fn) {\n    if constexpr (std::is_same<T, at::Tensor>::value) {\n        return check_fn(tensor_or_opt);\n    } else {\n        if (tensor_or_opt.has_value()) {\n            return check_fn(tensor_or_opt.value());\n        } else {\n            return true;\n        }\n    }\n}\n\n// Get the pointer of the given tensor\n// Return (PtrT*)tensor.data_ptr() if the tensor has a backend storage, nullptr otherwise\ntemplate<typename PtrT>\nstatic inline PtrT* get_tensor_ptr(const at::Tensor& tensor) {\n    if (tensor.has_storage()) {\n        return (PtrT*)tensor.data_ptr();\n    } else {\n        return nullptr;\n    }\n}\n\n// Get the pointer of the given tensor or optional tensor\n// Return (PtrT*)tensor.data_ptr() if tensor_or_opt has value and points to a valid tensor, return nullptr otherwise\ntemplate<typename PtrT, typename T>\nstatic inline PtrT* get_optional_tensor_ptr(const T& tensor_or_opt) {\n    if constexpr (std::is_same<T, at::Tensor>::value) {\n        return get_tensor_ptr<PtrT>(tensor_or_opt);\n    } else {\n        if (tensor_or_opt.has_value()) {\n            return get_tensor_ptr<PtrT>(*tensor_or_opt);\n        } else {\n            return nullptr;\n        }\n    }\n}\n\n}\n\n// Check whether the given tensor (or optional<tensor>) is on cuda\n#define KU_CHECK_DEVICE(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_cuda(); }), #tensor \" must be on CUDA\")\n\n// Check whether the 
given tensor (or optional<tensor>) has the given number of dimensions\n#define KU_CHECK_NDIM(tensor, ndim) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dim() == (ndim); }), #tensor \" must have \" #ndim \" dimensions\")\n\n// Check whether the given tensor (or optional<tensor>) has the given shape\n#define KU_CHECK_SHAPE(tensor, ...) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.sizes() == torch::IntArrayRef({__VA_ARGS__}); }), #tensor \" must have shape (\" #__VA_ARGS__ \")\")\n\n// Check whether the given tensor (or optional<tensor>) is contiguous\n#define KU_CHECK_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_contiguous(); }), #tensor \" must be contiguous\")\n\n// Check whether the last dimension of the given tensor (or optional<tensor>) is contiguous\n#define KU_CHECK_LAST_DIM_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.size(-1) == 1 || t.stride(-1) == 1; }), #tensor \" must have contiguous last dimension\")\n\n// Check whether the given tensor (or optional<tensor>) has the specified dtype\n#define KU_CHECK_DTYPE(tensor, target_dtype) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.dtype() == (target_dtype); }), #tensor \" must have dtype \" #target_dtype)\n"
  },
  {
    "path": "csrc/params.h",
    "content": "#pragma once\n\n#include <cstdint>\n#include <type_traits>\n\n#include <cuda_runtime_api.h>\n\n#include \"cutlass/bfloat16.h\"\n\nenum class ModelType {\n    V32,\n    MODEL1\n};\n\nstruct __align__(4*8) DecodingSchedMeta {\n    int begin_req_idx, end_req_idx;     // Both inclusive\n    int begin_block_idx, end_block_idx; // Inclusive, exclusive\n    int begin_split_idx;\n    int is_first_req_splitted, is_last_req_splitted;\n    int _pad[1];\n};\nstatic constexpr int DecodingSchedMetaSize = sizeof(DecodingSchedMeta);\n\nstruct DenseAttnDecodeParams {\n    using index_t = int64_t;\n\n    int b;              // batch size\n    int s_q;\n    int q_seq_per_hk;   // The number of q(s) per KV head, = h_q / h_k * s_q\n    int d, d_v;         // K/V dimension\n    int h_q, h_k;       // The number of Q/K heads\n    int num_blocks;     // Number of blocks in total\n    int q_head_per_hk;  // The number of q_head(s) per KV head, = h_q / h_k\n    bool is_causal;\n    float scale_softmax, scale_softmax_log2;\n    \n    void *__restrict__ q_ptr;\n    void *__restrict__ k_ptr;\n    void *__restrict__ o_ptr;\n    float *__restrict__ softmax_lse_ptr;\n\n    index_t q_batch_stride;\n    index_t k_batch_stride;\n    index_t o_batch_stride;\n    index_t q_row_stride;\n    index_t k_row_stride;\n    index_t o_row_stride;\n    index_t q_head_stride;\n    index_t k_head_stride;\n    index_t o_head_stride;\n\n    int *__restrict__ block_table;\n    index_t block_table_batch_stride;\n    int page_block_size;\n    int *__restrict__ seqlens_k_ptr;\n\n    DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;\n    int num_sm_parts;\n    int *__restrict__ num_splits_ptr;\n\n    int total_num_splits;\n    float *__restrict__ softmax_lseaccum_ptr;\n    float *__restrict__ oaccum_ptr;\n\n    cudaStream_t stream;\n};\n\nstruct SparseAttnDecodeParams {\n    int b, s_q;\n    int h_q, h_kv;\n    int d_qk, d_v;\n    float sm_scale, sm_scale_div_log2;\n    int num_blocks, page_block_size, topk;\n    ModelType 
model_type;\n\n    cutlass::bfloat16_t* __restrict__ q;   // [b, s_q, h_q, d_qk]\n    cutlass::bfloat16_t* __restrict__ kv;  // [num_blocks, page_block_size, d_qk]\n    int* __restrict__ indices;   // [b, s_q, topk]\n    int* __restrict__ topk_length;  // [b], may be nullptr\n    float* __restrict__ attn_sink;  // [h_q], may be nullptr\n\n    float* __restrict__ lse;    // [b, s_q, h_q]\n    cutlass::bfloat16_t* __restrict__ out;   // [b, s_q, h_q, d_v]\n    \n    int extra_num_blocks, extra_page_block_size, extra_topk;\n    cutlass::bfloat16_t* __restrict__ extra_kv;  // [extra_num_blocks, extra_page_block_size, d_qk]\n    int* __restrict__ extra_indices;   // [b, s_q, extra_topk]\n    int* __restrict__ extra_topk_length;  // [b], may be nullptr\n    \n    int stride_q_b, stride_q_s_q, stride_q_h_q;\n    int stride_kv_block, stride_kv_row;\n    int stride_indices_b, stride_indices_s_q;\n    int stride_lse_b, stride_lse_s_q;\n    int stride_o_b, stride_o_s_q, stride_o_h_q;\n    int stride_extra_kv_block, stride_extra_kv_row;\n    int stride_extra_indices_b, stride_extra_indices_s_q;\n    \n    cudaStream_t stream;\n    \n    // SplitKV-related parameters\n    float* __restrict__ lse_accum;  // [num_splits, s_q, h_q]\n    float* __restrict__ o_accum;    // [num_splits, s_q, h_q, d_v]\n    int stride_lse_accum_split, stride_lse_accum_s_q;\n    int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;\n    DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous\n    int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous\n    int num_sm_parts;\n};\n\nstruct CombineParams {\n    int b, s_q, h_q, d_v;\n\n    float* __restrict__ lse;    // [b, s_q, h_q]\n    void* __restrict__ out;   // [b, s_q, h_q, d_v]\n    int stride_lse_b, stride_lse_s_q;\n    int stride_o_b, stride_o_s_q, stride_o_h_q;\n\n    float* __restrict__ lse_accum;  // [num_splits, s_q, h_q]\n    float* __restrict__ o_accum;    // [num_splits, s_q, h_q, 
d_v]\n    int stride_lse_accum_split, stride_lse_accum_s_q;\n    int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;\n\n    DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous\n    int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous\n    int num_sm_parts;\n\n    float* attn_sink;  // [h_q], may be nullptr\n\n    cudaStream_t stream;\n};\n\nstruct GetDecodeSchedMetaParams {\n    int b;  // batch size\n    int s_q;\n    int block_size_n;\n    int fixed_overhead_num_blocks;\n\n    int topk, extra_topk;   // -1 if sparse attention (or extra topk) is disabled\n    int *__restrict__ topk_length, *__restrict__ extra_topk_length;\n\n    int *__restrict__ seqlens_k_ptr;    // Only necessary for dense attention\n\n    DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;\n    int *__restrict__ num_splits_ptr;\n    int num_sm_parts;\n\n    cudaStream_t stream;\n};\n\nstruct SparseAttnFwdParams {\n    int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk;\n    float sm_scale, sm_scale_div_log2;\n\n    // Input tensors\n    cutlass::bfloat16_t* __restrict__ q;    // [s_q, h_q, d_qk]\n    cutlass::bfloat16_t* __restrict__ kv;   // [s_kv, h_kv, d_qk]\n    int* __restrict__ indices;   // [s_q, h_kv, topk]\n    float* __restrict__ attn_sink;   // [h_q], may be nullptr\n    int* __restrict__ topk_length;    // [s_q], may be nullptr\n\n    // Strides\n    int stride_q_s_q; int stride_q_h_q;\n    int stride_kv_s_kv; int stride_kv_h_kv;\n    int stride_indices_s_q; int stride_indices_h_kv;\n\n    // Output tensors\n    cutlass::bfloat16_t* __restrict__ out;   // [s_q, h_q, d_v]\n    float* __restrict__ max_logits; // [s_q, h_q]\n    float* __restrict__ lse; // [s_q, h_q]\n\n    int num_sm;\n    cudaStream_t stream;\n};\n\n// We have some kernels that implement both prefill and decode modes in a single kernel (with different template instantiations). 
The following enum helps to distinguish the modes.\nenum class SparseAttnFwdMode {\n    Prefill,            // Normal prefill mode\n    DecodeWithSplitKV,  // To trigger decoding mode for kernels that support both prefill and decode\n};\n\ntemplate<SparseAttnFwdMode FWD_MODE>\ninline constexpr bool is_decode_v = std::bool_constant<FWD_MODE == SparseAttnFwdMode::DecodeWithSplitKV>::value;\n\ntemplate<SparseAttnFwdMode FWD_MODE>\nusing SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;\n"
  },
  {
    "path": "csrc/sm100/decode/head128/README.md",
    "content": "Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512); otherwise, head128 decoding is simulated using 2x head64 kernels."
  },
  {
    "path": "csrc/sm100/decode/head64/config.h",
    "content": "#pragma once\n\n#include \"kernel.h\"\n\n#include <cuda_fp8.h>\n#include <cutlass/barrier.h>\n#include <cute/tensor.hpp>\n\n#include <kerutils/kerutils.cuh>\n\n#include \"defines.h\"\n#include \"params.h\"\n\nnamespace sm100::decode::head64 {\n\nusing cutlass::arch::fence_view_async_shared;\nusing cutlass::arch::NamedBarrier;\nusing e8m0 = __nv_fp8_e8m0;\nusing e4m3 = cutlass::float_e4m3_t;\nusing namespace cute;\n\nenum NamedBarriers : uint32_t {\n    main_loop_sync = 0,\n    wg0_sync = 1,\n    wg0_warp02_sync = 2,\n    wg0_warp13_sync = 3,\n    everyone_sync = 4\n};\n\ntemplate<ModelType MODEL_TYPE>\nstruct KernelTemplate {\n\nstatic constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512;\nstatic constexpr int D_K = D_Q;\nstatic constexpr int D_V = 512;\nstatic constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 512 : 448;\nstatic constexpr int D_ROPE = 64;\nstatic constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;\nstatic constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true;\nstatic constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8;    // Padding is included\nstatic constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE;   // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens and 2) be large enough to cover the entire KV cache. Since a TMA copy's coordinates can only be 32-bit signed integers, this number must be >= 128, preferably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks.\nstatic_assert(D_NOPE + D_ROPE == D_Q);\nstatic_assert(V_HAVE_ROPE ? 
(D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V));\n\nstatic constexpr int B_H = 64;\nstatic constexpr int B_TOPK = 64;\nstatic constexpr int NUM_BUFS = 2;\nstatic constexpr int NUM_INDEX_BUFS = 4;    // Number of buffers for indices (tma_coords) & is_token_valid & scales\nstatic constexpr int NUM_THREADS = 128*3;  // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant\nstatic constexpr float MAX_INIT_VAL = -1e30f;  // To avoid (-inf) - (-inf) = NaN\n\nstatic constexpr int D_Q_SW128 = 512;\nstatic constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0;\nstatic_assert(D_Q_SW128 + D_Q_SW64 == D_Q);\nstatic constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 64 : 128; // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes\n\ntemplate<\n    typename Shape_Q_SW128, typename TMA_Q_SW128,\n    typename Shape_O, typename TMA_O\n>\nstruct TmaParams {\n    Shape_Q_SW128 shape_Q_SW128; TMA_Q_SW128 tma_Q_SW128;\n    Shape_O shape_O; TMA_O tma_O;\n    CUtensorMap tensor_map_q_sw64;  // Invalid if D_Q_SW64 == 0\n    CUtensorMap tensor_map_kv_nope;\n    CUtensorMap tensor_map_kv_rope;\n    CUtensorMap tensor_map_extra_kv_nope;\n    CUtensorMap tensor_map_extra_kv_rope;\n};\n\n// Tensor memory columns\nstruct tmem_cols {\n    //   0 ~ 256: output\n    // 256 ~ 256 + 64*D_Q/256: Q\n    // 400 ~ 464: P\n    static constexpr int O = 0;\n    static constexpr int Q = 256;\n    static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128;\n    static constexpr int P = 400;\n};\n\ntemplate<int NUM_TILES>\nusing SmemLayoutQTiles = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<NUM_TILES*64>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutQ_SW128 = SmemLayoutQTiles<D_Q_SW128/64>;\n\nusing SmemLayoutOBuf = decltype(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<D_V>>{}\n));\n\nusing SmemLayoutOBuf_TMA = 
decltype(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<64>>{}\n)); // A TMA tile\n\nstatic_assert(D_V == 512);\nusing SmemLayoutOAccumBuf = Layout<\n    Shape<Int<B_H>, Int<D_V>>,\n    Stride<Int<520>, _1>\t// We use stride = 520 here to avoid bank conflict\n>;\n\nusing SmemLayoutS = decltype(tile_to_shape(\n    UMMA::Layout_K_INTER_Atom<bf16>{},\n    Shape<Int<B_H>, Int<B_TOPK>>{},\n    Step<_1, _2>{}\n));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H*2>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTilesTransposed_SW128 = decltype(composition(\n    SmemLayoutKTiles_SW128<NUM_TILES>{},\n    Layout<\n        Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,\n        Stride<Int<B_TOPK>, _1>\n    >{}\n));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW64_Atom<bf16>{},\n    Shape<Int<B_H>, Int<32*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW64_Atom<bf16>{},\n    Shape<Int<B_H*2>, Int<32*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTilesTransposed_SW64 = decltype(composition(\n    SmemLayoutKTiles_SW64<NUM_TILES>{},\n    Layout<\n        Shape<Int<32*NUM_TILES>, Int<B_TOPK>>,\n        Stride<Int<B_TOPK>, _1>\n    >{}\n));\n\nstruct SharedMemoryPlan {\n    union {\n        struct {\n            array_aligned<bf16, cosize_v<SmemLayoutQ_SW128>> q;\n            bf16 q_sw64[B_H*D_Q_SW64];  // NOTE D_Q_SW64 
may be 0 but array_aligned<bf16, 0> will have a size of 16, so we use a plain array here. The preceding member (`q`) guarantees its alignment.\n            union {\n                array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;\n                array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;\n            } o;\n        } qo;\n        struct {\n            struct {\n                array_aligned<bf16, B_H*D_NOPE> nope; // NoPE part, dequantized\n                array_aligned<bf16, B_H*D_ROPE> rope; // RoPE part, dequantized. SW64 in V32 mode, SW128 in MODEL1 mode\n            } dequant[NUM_BUFS];\n            static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q)); // So that Q does not cover raw_nope\n            array_aligned<e4m3, B_H*D_NOPE> raw_nope[NUM_BUFS];  // Raw (quantized) NoPE part\n        } kv;\n    } u;\n    union {\n        float4 p_exchange_buf[4][16 * B_TOPK / 4];\n        array_aligned<bf16, cosize_v<SmemLayoutS>> s;\n    } s_p;\n    CUTE_ALIGNAS(16) float rowwise_max_buf[128];\n    char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8];\n    int tma_coord[NUM_INDEX_BUFS][B_TOPK];\n    e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN];\n    array_aligned<uint32_t, 1> tmem_start_addr;\n    transac_bar_t bar_last_store_done;\n    transac_bar_t bar_q_tma, bar_q_utccp;\n    transac_bar_t bar_rope_ready[NUM_BUFS];\n    transac_bar_t bar_nope_ready[NUM_BUFS];\n    transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS];\n    transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS];\n    transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS];\n};\n\nusing TiledMMA_P = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}\n)); // *2 for dual gemm\n\nusing TiledMMA_O = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, 
UMMA::Major::MN>{}\n));\n\ntemplate<typename TmaParam>\nstatic __device__ void\nflash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params);\n\nstatic void run(const SparseAttnDecodeParams &params);\n\n};\n\n}"
  },
  {
    "path": "csrc/sm100/decode/head64/instantiations/model1.cu",
    "content": "#include \"../kernel.cuh\"\n\nnamespace sm100::decode::head64 {\n\ntemplate\nvoid run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1>(const SparseAttnDecodeParams &params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/decode/head64/instantiations/v32.cu",
    "content": "#include \"../kernel.cuh\"\n\nnamespace sm100::decode::head64 {\n\ntemplate\nvoid run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32>(const SparseAttnDecodeParams &params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/decode/head64/kernel.cuh",
    "content": "#include \"kernel.h\"\n\n#include <math_constants.h>\n#include <cutlass/barrier.h>\n#include <cutlass/arch/barrier.h>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cute/tensor.hpp>\n#include <cute/arch/tmem_allocator_sm100.hpp>\n\n#include \"kerutils/kerutils.cuh\"\n\n#include \"utils.h\"\n#include \"sm100/helpers.h\"\n\n#include \"config.h\"\n\nnamespace sm100::decode::head64 {\n\ntemplate<ModelType MODEL_TYPE>\ntemplate<typename TmaParam>\n__device__ void\nKernelTemplate<MODEL_TYPE>\n::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params) {\n#if defined(KERUTILS_ENABLE_SM100A)\n    const int s_q_idx = blockIdx.x;\n    const int partition_idx = blockIdx.y;\n    const int warpgroup_idx = cutlass::canonical_warp_group_idx();\n    const int idx_in_warpgroup = threadIdx.x % 128;\n    const int warp_idx = cutlass::canonical_warp_idx_sync();\n    const int lane_idx = threadIdx.x % 32;\n\n    extern __shared__ char wksp_buf[];\n    SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);\n\n    if (warp_idx == 0 && elect_one_sync()) {\n        cute::prefetch_tma_descriptor(tma_params.tma_Q_SW128.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(&tma_params.tensor_map_q_sw64);\n        cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);\n        cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);\n    }\n\n    if (warp_idx == 0) {\n        if (elect_one_sync()) {\n            plan.bar_last_store_done.init(128);\n            plan.bar_q_tma.init(1);\n            plan.bar_q_utccp.init(1);\n            for (int i = 0; i < NUM_BUFS; ++i) {\n                plan.bar_rope_ready[i].init(1);\n                plan.bar_nope_ready[i].init(128); \n                plan.bar_raw_ready[i].init(1);\n                plan.bar_raw_free[i].init(128);\n                
plan.bar_qk_done[i].init(1);\n                plan.bar_so_ready[i].init(128);\n                plan.bar_sv_done[i].init(1);\n            }\n            for (int i = 0; i < NUM_INDEX_BUFS; ++i) {\n                plan.bar_valid_coord_scale_ready[i].init(32);\n                plan.bar_valid_coord_scale_free[i].init(128+128+1+1);\n            }\n            cutlass::arch::fence_barrier_init();\n        }\n        cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());\n        KU_TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);\n        cute::TMEM::Allocator1Sm().release_allocation_lock();\n    }\n    __syncthreads();\n\n    struct MainLoopArgs {\n        int batch_idx, start_block_idx, end_block_idx;\n        bool is_no_split; int n_split_idx;\n        bool bar_phase_batch_rel;    // Bar phase of barriers that are used once per batch\n        int topk_length, extra_topk_length, num_orig_kv_blocks;\n        bool is_last_batch;\n    };\n\n    auto run_main_loop = [&](auto f) {\n        // NOTE Putting the following code outside the warpgroup specialization switch results in register spilling.\n        // [[maybe_unused]] int begin_req_idx, end_req_idx, sched_begin_block_idx, sched_end_block_idx, begin_n_split_idx, is_first_req_splitted, is_last_req_splitted;\n        DecodingSchedMeta sched_meta;\n        KU_LDG_256(\n            params.tile_scheduler_metadata_ptr + partition_idx,\n            &sched_meta,\n            \".nc\",\n            \"no_allocate\",\n            \"evict_normal\",\n            \"256B\"\n        );\n\n        if (sched_meta.begin_req_idx >= params.b) {\n            return;\n        }\n        \n        bool bar_phase_batch_rel = 0;\n        #pragma unroll 1\n        for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx, bar_phase_batch_rel ^= 1) {\n            int topk_length = params.topk_length ? 
__ldg(params.topk_length + batch_idx) : params.topk;\n            int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);\n            int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;\n            int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK);    // % B_TOPK == 0\n            int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;\n            int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK;\n            bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false);\n            int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx);\n\n            MainLoopArgs args = {\n                batch_idx, start_block_idx, end_block_idx,\n                !is_split, n_split_idx,\n                bar_phase_batch_rel,\n                topk_length, extra_topk_length,\n                orig_topk_padded / B_TOPK,\n                batch_idx == sched_meta.end_req_idx\n            };\n\n            f(args);\n            NamedBarrier(NUM_THREADS, NamedBarriers::everyone_sync).arrive_and_wait_unaligned();\n        }\n    };\n\n    struct RingState {\n        int buf_idx = 0;\n        bool bar_phase = 0;\n        int index_buf_idx = 0;\n        bool index_bar_phase = 0;\n        CUTE_DEVICE void update() {\n            bar_phase ^= (buf_idx == NUM_BUFS-1);\n            buf_idx = (buf_idx+1) % NUM_BUFS;\n            index_bar_phase ^= (index_buf_idx == NUM_INDEX_BUFS-1);\n            index_buf_idx = (index_buf_idx+1) % NUM_INDEX_BUFS;\n        }\n    };\n    RingState rs;\n\n    if (warpgroup_idx == 0) {\n        // Scale & Exp warpgroup\n       
 // The same technique (and highly similar code) as the sm100 sparse prefill head64 kernel\n        cutlass::arch::warpgroup_reg_alloc<224>();\n\n        constexpr int B_EPI = 64;   // Must be equal to the size of the swizzle atom\n        Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_buf.data()), SmemLayoutOBuf{});\n        bf16* sO_bases[B_EPI/8];   // 64 is the size of the swizzle atom (in number of elements) while 8 is the width of each write\n        CUTE_UNROLL\n        for (int i = 0; i < B_EPI/8; ++i)\n            sO_bases[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*128 + i*8);\n\n        const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};\n        bf16* sS_base = plan.s_p.s.data() + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2);\n\n        float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg((float*)params.attn_sink + (idx_in_warpgroup%64)) * CUDART_L2E_F;\n        \n        run_main_loop([&](const MainLoopArgs &args) {\n            cute::tma_store_wait<0>();\n            plan.bar_last_store_done.arrive();\n\n            float mi = MAX_INIT_VAL;\n            float li = 0.0f;\n            float real_mi = -CUDART_INF_F;\n\n            CUTE_NO_UNROLL\n            for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);  // Make sure all intermediate buffers (including p_exchange_buf, rowwise max_buf) are free\n                plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);    // Put the barrier wait here for more code reordering space\n                plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase);\n                ku::tcgen05_after_thread_sync();\n\n                // Load P\n                float p[B_TOPK/2], p_peer[B_TOPK/2];\n                if (warp_idx < 2) {\n                    ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p);\n      
              ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p_peer);\n                } else {\n                    ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p_peer);\n                    ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p);\n                }\n                cutlass::arch::fence_view_async_tmem_load();\n                ku::tcgen05_before_thread_sync();\n\n                // Reduce within shared mem\n                {\n                    // Store\n                    // Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (col 0 ~ 31) part\n                    CUTE_UNROLL\n                    for (int i = 0; i < (B_TOPK/2)/4; ++i)\n                        plan.s_p.p_exchange_buf[warp_idx^2][i*32 + lane_idx] = *(float4*)(p_peer + i*4);\n                    NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1));  // Synchronize between warp 0 and warp 2, as well as warp 1 - warp 3\n                    // Load\n                    CUTE_UNROLL\n                    for (int i = 0; i < (B_TOPK/2)/4; ++i) {\n                        float2 t[2];\n                        *(float4*)t = plan.s_p.p_exchange_buf[warp_idx][i*32 + lane_idx];\n                        float2* cur_p = (float2*)(p + i*4);\n                        cur_p[0] = ku::float2_add(cur_p[0], t[0]);\n                        cur_p[1] = ku::float2_add(cur_p[1], t[1]);\n                    }\n                }\n\n                // Since dual gemm is utilized, the layout of P in registers now looks like:\n                // \n                //         32      32\n                //     +-------+-------+\n                //     |       |       |\n                // 32  | Warp0 | Warp2 |\n                //     |       |       |\n                //     +-------+-------+\n                //     |       |       |\n                // 32  | Warp1 | Warp3 |\n                //     |       |       |\n                //     +-------+-------+\n\n    
            // Mask\n                uint32_t valid_mask = *((uint32_t*)plan.is_token_valid[rs.index_buf_idx] + (idx_in_warpgroup>=64?1:0));\n                CUTE_UNROLL\n                for (int i = 0; i < B_TOPK/2; i += 1) {\n                    if (!(valid_mask>>i&1))\n                        p[i] = -CUDART_INF_F;\n                }\n                \n                // Get rowwise max of Pi\n                float cur_pi_max = -CUDART_INF_F;\n                CUTE_UNROLL\n                for (int i = 0; i < (B_TOPK/2); i += 1) {\n                    cur_pi_max = max(cur_pi_max, p[i]);\n                }\n                cur_pi_max *= params.sm_scale_div_log2;\n\n                plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;\n                NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);    // This also separates \"reading p_exchange_buf\" and \"writing S\"\n                plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();\n                cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);\n                real_mi = max(real_mi, cur_pi_max);\n                bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);\n                // By this point:\n                // - cur_pi_max, real_mi, and mi are identical within each row (i.e. thread 0+64, 1+65, ...)\n                // - should_scale_o is identical within every warp, and is identical among threads that control the same row (i.e. 
among threads 0~31+64~95; and is identical among threads 32~63+96~127)\n\n                // Calc scale factor, and scale li\n                float new_max, scale_for_old;\n                if (!should_scale_o) {\n                    // Don't scale O\n                    scale_for_old = 1.0f;\n                    new_max = mi;\n                } else {\n                    new_max = max(cur_pi_max, mi);\n                    scale_for_old = exp2f(mi - new_max);\n                }\n                mi = new_max;   // mi is still identical within each row\n\n                // Calculate S\n                __nv_bfloat162 s[(B_TOPK/2)/2];\n                float2 neg_new_max = float2 {-new_max, -new_max};\n                float2 cur_sum = float2 {0.0f, 0.0f};\n                CUTE_UNROLL\n                for (int i = 0; i < (B_TOPK/2)/2; i += 1) {\n                    float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale, neg_new_max);\n                    d.x = exp2f(d.x);\n                    d.y = exp2f(d.y);\n                    cur_sum = ku::float2_add(cur_sum, d);\n                    s[i] = __float22bfloat162_rn(d);\n                }\n                li = fma(li, scale_for_old, (cur_sum.x + cur_sum.y));\n\n                // Write S\n                CUTE_UNROLL\n                for (int i = 0; i < B_TOPK/2/8; i += 1) {\n                    *(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4);\n                }\n\n                // Scale O\n                if (block_idx != args.start_block_idx && should_scale_o) {\n                    float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old}; \n                    ku::tcgen05_after_thread_sync();\n\n                    static constexpr int CHUNK_SIZE = 64;\n                    float2 o[CHUNK_SIZE/2];\n                    CUTE_UNROLL\n                    for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {\n                        // Load O\n                        
ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::O + chunk_idx*CHUNK_SIZE, o);\n                        cutlass::arch::fence_view_async_tmem_load();\n\n                        // Mult\n                        for (int i = 0; i < CHUNK_SIZE/2; ++i) {\n                            o[i] = ku::float2_mul(o[i], scale_for_old_float2);\n                        }\n\n                        // Store O\n                        ku::tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::O + chunk_idx*CHUNK_SIZE, o);\n                        cutlass::arch::fence_view_async_tmem_store();\n                    }\n                    ku::tcgen05_before_thread_sync();\n                }\n                \n                fence_view_async_shared();\n                plan.bar_so_ready[rs.buf_idx].arrive();\n\n                if (block_idx != args.end_block_idx-1) {\n                    rs.update();    // Don't update rs for the last round since we want to wait for the last SV gemm\n                }\n            }\n\n            if (real_mi == -CUDART_INF_F) {\n                // real_mi == -CUDART_INF_F <=> No valid TopK indices\n                // We set li to 0 to fit the definition that li := sum(exp(x[i] - mi)), which is an empty sum here\n                li = 0.0f;\n                mi = -CUDART_INF_F;\n            }\n\n            // Exchange li\n            plan.rowwise_max_buf[idx_in_warpgroup] = li;\n            NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);\n            li += plan.rowwise_max_buf[idx_in_warpgroup^64];\n\n            // Store li\n            if (idx_in_warpgroup < B_H) {\n                if (args.is_no_split) {\n                    float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));\n                    cur_lse = cur_lse == -CUDART_INF_F ? 
+CUDART_INF_F : cur_lse;\n                    float* gSoftmaxLse = (float*)params.lse + args.batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + idx_in_warpgroup;\n                    *gSoftmaxLse = cur_lse;\n                } else {\n                    float cur_lse = log2f(li) + mi;\n                    float* gSoftmaxLseAccum = (float*)params.lse_accum + args.n_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + idx_in_warpgroup;\n                    *gSoftmaxLseAccum = cur_lse;\n                }\n            }\n        \n            plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase);\n            rs.update();\n            ku::tcgen05_after_thread_sync();\n\n            if (args.is_last_batch) {\n                cudaTriggerProgrammaticLaunchCompletion();\n            }\n\n            if (args.is_no_split) {\n                Tensor tma_gO = flat_divide(\n                    tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, args.batch_idx),\n                    Shape<Int<B_H>, Int<64>>{}\n                )(_, _, _0{}, _);\n                auto thr_tma = tma_params.tma_O.get_slice(_0{});\n                Tensor tma_sO = flat_divide(\n                    sO,\n                    Shape<Int<B_H>, Int<64>>{}\n                )(_, _, _0{}, _);\n\n                float o_scale = li == 0.0f ? 
0.0f : __fdividef(1.0f, li + exp2f(attn_sink - mi));\n                float2 o_scale_float2 = {o_scale, o_scale};\n                float2 o[B_EPI/2];\n                __nv_bfloat162 o_bf16[B_EPI/2];\n                CUTE_UNROLL\n                for (int i = 0; i < (D_V/2) / B_EPI; ++i) {\n                    // Load\n                    ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + i*B_EPI, o);\n                    cutlass::arch::fence_view_async_tmem_load();\n                    // Scale & Convert\n                    CUTE_UNROLL\n                    for (int j = 0; j < B_EPI/2; ++j) {\n                        o[j] = ku::float2_mul(o[j], o_scale_float2);\n                        o_bf16[j] = __float22bfloat162_rn(o[j]);\n                    }\n                    // Store\n                    int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));\n                    CUTE_UNROLL\n                    for (int j = 0; j < B_EPI / 8; ++j)\n                        *(__int128_t*)(sO_bases[j] + col_base*B_H) = *(__int128_t*)(&o_bf16[j*4]);\n                    // Sync\n                    fence_view_async_shared();\n                    NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);\n                    // S -> G\n                    if (warp_idx == 0 && elect_one_sync()) {\n                        cute::copy(\n                            tma_params.tma_O,\n                            thr_tma.partition_S(tma_sO(_, _, col_base/64)),\n                            thr_tma.partition_D(tma_gO(_, _, col_base/64))\n                        );\n                    }\n                    if (warp_idx == 1 && elect_one_sync()) {\n                        cute::copy(\n                            tma_params.tma_O,\n                            thr_tma.partition_S(tma_sO(_, _, col_base/64 + (D_V/4)/64)),\n                            thr_tma.partition_D(tma_gO(_, _, col_base/64 + (D_V/4)/64))\n                        );\n                    }\n                }\n       
         cute::tma_store_arrive();\n            } else {\n                float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li);   // Here we leave attn_sink to the combine kernel; otherwise attn_sink would take effect multiple times\n                float2 o_scale_float2 = {o_scale, o_scale};\n                constexpr int B_EPI = 64;\n                float2 o[B_EPI/2];\n                Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_accum_buf.data()), SmemLayoutOAccumBuf{});\n                CUTE_UNROLL\n                for (int i = 0; i < (D_V/2) / B_EPI; ++i) {\n                    // Load\n                    ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + i*B_EPI, o);\n                    cutlass::arch::fence_view_async_tmem_load();\n                    // Scale & Convert\n                    CUTE_UNROLL\n                    for (int j = 0; j < B_EPI/2; ++j)\n                        o[j] = ku::float2_mul(o[j], o_scale_float2);\n                    // Store\n                    int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? 
D_V/2 : 0) + (i*B_EPI%(D_V/4));\n                    CUTE_UNROLL\n                    for (int j = 0; j < B_EPI / 4; ++j)\n                        *(__int128_t*)&sO(idx_in_warpgroup%64, col_base + j*4) = *(__int128_t*)(&o[j*2]);\n                }\n                fence_view_async_shared();\n                NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);\n                if (elect_one_sync()) {\n                    CUTE_UNROLL\n                    for (int local_row = 0; local_row < B_H/4; ++local_row) {\n                        int smem_row = local_row*4 + warp_idx;\n                        SM90_BULK_COPY_S2G::copy(\n                            &sO(smem_row, _0{}),\n                            (float*)params.o_accum + args.n_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + smem_row*params.stride_o_accum_h_q,\n                            D_V*sizeof(float)\n                        );\n                    }\n                    cute::tma_store_arrive();\n                }\n            }\n        });\n\n        if (warp_idx == 0) {\n            cute::TMEM::Allocator1Sm().free(0, 512);\n        }\n    } else if (warpgroup_idx == 1) {\n        cutlass::arch::warpgroup_reg_dealloc<72>();\n        const int warp_idx = cutlass::canonical_warp_idx_sync();    // Missing this leads to reg spilling\n\n        if (warp_idx == 4 && elect_one_sync()) {\n\n            // MMA Warp\n            run_main_loop([&](const MainLoopArgs &args) {\n                if (args.start_block_idx >= args.end_block_idx) {\n                    ku::trap();\n                }\n                // Issue Q (SW128) G->S\n                {\n                    Tensor gQ = tma_params.tma_Q_SW128.get_tma_tensor(tma_params.shape_Q_SW128)(_, _, s_q_idx, args.batch_idx);\n                    Tensor sQ = make_tensor(make_smem_ptr(plan.u.qo.q.data()), SmemLayoutQ_SW128{});\n                    ku::launch_tma_copy(\n                        tma_params.tma_Q_SW128,\n            
            gQ,\n                        sQ,\n                        plan.bar_q_tma,\n                        TMA::CacheHintSm90::EVICT_FIRST\n                    );\n                }\n                // Issue Q (SW64) G -> S\n                if constexpr (D_Q_SW64 > 0) {\n                    cute::SM90_TMA_LOAD_5D::copy(\n                        &tma_params.tensor_map_q_sw64,\n                        (uint64_t*)&plan.bar_q_tma,\n                        (uint64_t)TMA::CacheHintSm90::EVICT_FIRST,\n                        plan.u.qo.q_sw64,\n                        0, 0, 0,\n                        s_q_idx, args.batch_idx\n                    );\n                }\n                plan.bar_q_tma.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));\n                plan.bar_q_tma.wait(args.bar_phase_batch_rel);\n                ku::tcgen05_after_thread_sync();\n                // Issue Q (SW128) UTCCP\n                {\n                    UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(\n                        make_tensor(\n                            make_smem_ptr(plan.u.qo.q.data()),\n                            tile_to_shape(\n                                UMMA::Layout_K_SW128_Atom<bf16>{},\n                                Shape<Int<B_H*2>, Int<64>>{}  // *2 to leverage dual GEMM\n                            )\n                        )\n                    );\n                    static_assert(D_Q_SW128%128 == 0);\n                    CUTE_UNROLL\n                    for (int tile_idx = 0; tile_idx < D_Q_SW128/128; ++tile_idx) {\n                        // Each tile: 64 x (64*2) logically, 128 x 64 bf16 on TMEM\n                        CUTE_UNROLL\n                        for (int subtile_idx = 0; subtile_idx < 64/16; ++subtile_idx) {\n                            // Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM\n                            SM100_UTCCP_128dp256bit_1cta::copy(\n                                sQ_desc + 
(tile_idx*(B_H*128) + subtile_idx*16) * 2 / 16,\n                                tmem_cols::Q + tile_idx*32 + subtile_idx*8\n                            );\n                        }\n                    }\n                }\n                // Issue Q (SW64) UTCCP\n                if constexpr (D_Q_SW64 > 0) {\n                    UMMA::SmemDescriptor sQ_SW64_desc = UMMA::make_umma_desc<UMMA::Major::K>(\n                        make_tensor(\n                            make_smem_ptr(plan.u.qo.q_sw64),\n                            tile_to_shape(\n                                UMMA::Layout_K_SW64_Atom<bf16>{},\n                                Shape<Int<B_H*2>, Int<32>>{}  // *2 to leverage dual GEMM\n                            )\n                        )\n                    );\n                    static_assert(D_Q_SW64%64 == 0);\n                    CUTE_UNROLL\n                    for (int tile_idx = 0; tile_idx < D_Q_SW64/64; ++tile_idx) {\n                        // Each tile: 64 x (32*2) logically, 128 x 32 bf16 on TMEM\n                        CUTE_UNROLL\n                        for (int subtile_idx = 0; subtile_idx < 32/16; ++subtile_idx) {\n                            // Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM\n                            SM100_UTCCP_128dp256bit_1cta::copy(\n                                sQ_SW64_desc + (tile_idx*(B_H*64) + subtile_idx*16) * 2 / 16,\n                                tmem_cols::Q + (B_H*D_Q_SW128/2/128) + tile_idx*16 + subtile_idx*8\n                            );\n                        }\n                    }\n                }\n                ku::umma_arrive_noelect(plan.bar_q_utccp);\n\n                // Allocate tmem tensors\n                TiledMMA tiled_mma_P = TiledMMA_P{};\n                TiledMMA tiled_mma_O = TiledMMA_O{};\n                // NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm)\n              
  Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<B_H>, _128>{});\n                Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H>, Int<D_V>>{});\n                tP.data().get() = tmem_cols::P;\n                tO.data().get() = tmem_cols::O;\n\n                // Wait for UTCCP\n                plan.bar_q_utccp.wait(args.bar_phase_batch_rel);\n                ku::tcgen05_after_thread_sync();\n\n                // Mainloop\n                CUTE_NO_UNROLL\n                for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                    if constexpr (MODEL_TYPE == ModelType::V32) {\n                        // V3.2: RoPE behaves like an extra block with size 64, so we can do RoPE first\n                        // QK RoPE\n                        plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase);\n                        ku::tcgen05_after_thread_sync();\n                        Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A(\n                            partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_ROPE/2>>{})\n                        );\n                        tQ_rope.data().get() = tmem_cols::Q_Tail;\n                        Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].rope.data()), SmemLayoutKTiles_DualGemm_SW64<2/2>{});\n                        ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true);\n\n                        // QK NoPE\n                        plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase);\n                        ku::tcgen05_after_thread_sync();\n                        Tensor tQ_nope = tiled_mma_P.get_slice(_0{}).make_fragment_A(\n                            partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_NOPE/2>>{})\n                        );\n                        tQ_nope.data().get() = tmem_cols::Q;\n                        Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), 
SmemLayoutKTiles_DualGemm_SW128<512/64/2>{});\n                        ku::utcmma_ts(tiled_mma_P, tQ_nope, sK_nope, tP, false);\n                    } else {\n                        // MODEL1: RoPE is the last 64 dims within the full 512 dim, which couples with the last 64 dim from the NoPE part when performing dual GEMM. i.e.\n                        // \n                        // logical view: |0|1|2|3|4|5|6|7| (where 7 is the RoPE part)\n                        // dual gemm's view: \n                        // |0|2|4|6|\n                        // |1|3|5|7|\n                        // \n                        // So we must wait for both the NoPE and the RoPE part, and then perform dual GEMM\n                        plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase);\n                        plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase);\n                        ku::tcgen05_after_thread_sync();\n\n                        Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A(\n                            partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_Q/2>>{})\n                        );\n                        tQ.data().get() = tmem_cols::Q;\n                        Tensor sK = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{});\n                        ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true);\n                    }\n                    ku::umma_arrive_noelect(plan.bar_qk_done[rs.buf_idx]);\n\n                    // SV\n                    plan.bar_so_ready[rs.buf_idx].wait(rs.bar_phase);\n                    ku::tcgen05_after_thread_sync();\n                    Tensor sS = make_tensor(make_smem_ptr(plan.s_p.s.data()), SmemLayoutS{});\n                    Tensor sV = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTilesTransposed_SW128<D_V/64>{});  // NOTE: For MODEL1, it \"expands\" to the RoPE part.\n                    
ku::utcmma_ss(tiled_mma_O, sS, sV, tO, block_idx == args.start_block_idx);\n                    ku::umma_arrive_noelect(plan.bar_sv_done[rs.buf_idx]);\n\n                    rs.update();\n                }\n            });\n        } else if (warp_idx == 5 && elect_one_sync()) {\n            // Raw KV NoPE retrieval warp\n            run_main_loop([&](const MainLoopArgs &args) {\n                plan.bar_q_utccp.wait(args.bar_phase_batch_rel);\n                plan.bar_last_store_done.wait(args.bar_phase_batch_rel);\n                CUTE_NO_UNROLL\n                for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                    plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);\n                    plan.bar_raw_free[rs.buf_idx].wait(rs.bar_phase^1);\n                    int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0);\n                    int4 nxt_cur_indices;\n                    CUTE_UNROLL\n                    for (int row = 0; row < B_TOPK; row += 4) {\n                        if (row+4 < B_TOPK)\n                            nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4);\n                        ku::tma_gather4(\n                            block_idx >= args.num_orig_kv_blocks ? 
&tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope,\n                            plan.bar_raw_ready[rs.buf_idx],\n                            plan.u.kv.raw_nope[rs.buf_idx].data() + D_NOPE*row,\n                            0,\n                            cur_indices,\n                            (int64_t)TMA::CacheHintSm90::EVICT_LAST\n                        );\n                        cur_indices = nxt_cur_indices;\n                    }\n                    plan.bar_raw_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_NOPE*sizeof(e4m3));\n                    plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();\n                    rs.update();\n                }\n            });\n        } else if (warp_idx == 6 && elect_one_sync()) {\n            // KV RoPE retrieval warp\n            run_main_loop([&](const MainLoopArgs &args) {\n                plan.bar_q_utccp.wait(args.bar_phase_batch_rel);\n                plan.bar_last_store_done.wait(args.bar_phase_batch_rel);\n                CUTE_NO_UNROLL\n                for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                    plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);\n                    if constexpr (MODEL_TYPE == ModelType::V32) {\n                        plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase^1);\n                    } else {\n                        plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1);\n                    }\n                    int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0);\n                    int4 nxt_cur_indices;\n                    CUTE_UNROLL\n                    for (int row = 0; row < B_TOPK; row += 4) {\n                        if (row+4 < B_TOPK)\n                            nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4);\n                        CUTE_UNROLL\n                        for (int t = 0; t < D_ROPE/(K_ROPE_SW/2); 
++t) {\n                            ku::tma_gather4(\n                                block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope,\n                                plan.bar_rope_ready[rs.buf_idx],\n                                plan.u.kv.dequant[rs.buf_idx].rope.data() + (K_ROPE_SW/2)*row + t*B_TOPK*(K_ROPE_SW/2),\n                                t*(K_ROPE_SW/2),\n                                cur_indices,\n                                (int64_t)TMA::CacheHintSm90::EVICT_LAST\n                            );\n                        }\n                        cur_indices = nxt_cur_indices;\n                    }\n                    plan.bar_rope_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16));\n                    plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();\n                    rs.update();\n                }\n            });\n        } else if (warp_idx == 7) {\n            // Indices transformation warp\n            // Responsible for generating: TMA coordinates, scale factors, and valid masks\n            static_assert(B_TOPK == 64);\n            static constexpr int tma_coords_step_per_token = MODEL_TYPE == ModelType::V32 ? 
656/TMA_K_STRIDE : 576/TMA_K_STRIDE;\n            int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE > 512\n            int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE;\n            uint8_t* k_scales_ptr =\n                MODEL_TYPE == ModelType::V32 ?\n                (uint8_t*)params.kv + D_NOPE :\n                (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE);\n            uint8_t* extra_k_scales_ptr =\n                MODEL_TYPE == ModelType::V32 ?\n                (uint8_t*)params.extra_kv + D_NOPE :\n                (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE);\n            \n            run_main_loop([&](const MainLoopArgs &args) {\n                int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*s_q_idx;\n                int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*s_q_idx;\n                \n                struct IsOrigBlock {};\n                struct IsExtraBlock {};\n                auto process_one_block = [&](int block_idx, auto is_extra_block_t) {\n                    static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;\n                    int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size;\n                    int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block;\n                    [[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row;\n                    uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr;\n                    int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? 
tma_coords_step_per_extra_block : tma_coords_step_per_block;\n\n                    int abs_pos, my_indices[2];\n                    if (!IS_EXTRA_BLOCK) {\n                        abs_pos = block_idx*B_TOPK + lane_idx*2;\n                        *(int2*)my_indices = __ldg((int2*)(indices + abs_pos));\n                    } else {\n                        abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2;\n                        *(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos));\n                    }\n                    plan.bar_valid_coord_scale_free[rs.index_buf_idx].wait(rs.index_bar_phase^1);\n\n                    int tma_coords[2];\n                    e8m0 scales[2*NUM_SCALES_EACH_TOKEN];\n                    char valid_mask = 0;\n                    CUTE_UNROLL\n                    for (int i = 0; i < 2; ++i) {\n                        int block_idx, idx_in_block;\n                        block_idx = (unsigned int)my_indices[i] / cur_block_size;\n                        idx_in_block = (unsigned int)my_indices[i] % cur_block_size;\n                        bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length));\n                        valid_mask |= is_token_valid << i;\n                        tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because its topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying in NaN.\n                        if constexpr (MODEL_TYPE == ModelType::V32) {\n                            int64_t offset = is_token_valid ? 
block_idx*cur_k_block_stride + idx_in_block*cur_k_row_stride : 0;\n                            float4 cur_scale_fp32 = __ldg((float4*)(cur_k_scales_ptr + offset));\n                            e8m0 res[4];\n                            *(__nv_fp8x2_storage_t*)(res+0) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.x, cur_scale_fp32.y}, __NV_NOSAT, cudaRoundZero);\n                            *(__nv_fp8x2_storage_t*)(res+2) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.z, cur_scale_fp32.w}, __NV_NOSAT, cudaRoundZero);\n                            if (!is_token_valid) *(uint32_t*)res = (uint32_t)0;\n                            *(uint32_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = *(uint32_t*)(res);\n                        } else {\n                            int64_t offset = block_idx*cur_k_block_stride + idx_in_block*8; // Each token has 7 scale factors with an extra 1B padding\n                            uint64_t scalesx8 = is_token_valid ? __ldg((uint64_t*)(cur_k_scales_ptr + offset)) : 0;\n                            *(uint64_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = scalesx8;\n                        }\n                    }\n                    valid_mask <<= lane_idx%4*2;\n                    valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1);\n                    valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2);\n                    if constexpr (MODEL_TYPE == ModelType::V32) {\n                        *(uint64_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(uint64_t*)scales;\n                    } else {\n                        *(__int128_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(__int128_t*)scales;\n                    }\n                    *(int2*)(plan.tma_coord[rs.index_buf_idx] + lane_idx*2) = *(int2*)tma_coords;\n                    if (lane_idx%4 == 0)\n                        plan.is_token_valid[rs.index_buf_idx][lane_idx/4] = valid_mask;\n                    \n                    
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].arrive();\n                    rs.update();\n                };\n\n                CUTE_NO_UNROLL\n                for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {\n                    process_one_block(block_idx, IsOrigBlock{});\n                }\n\n                CUTE_NO_UNROLL\n                for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) {\n                    process_one_block(block_idx, IsExtraBlock{});\n                }\n            });\n        } else {\n            run_main_loop([&](const MainLoopArgs &args) {});\n        }\n    } else {\n        // Dequant warpgroup\n        cutlass::arch::warpgroup_reg_alloc<208>();\n\n        // 8 threads per token\n        constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = D_NOPE/(GROUP_SIZE*8);\n        int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE;\n        Tensor nope0 = make_tensor(make_smem_ptr(plan.u.kv.dequant[0].nope.data()), SmemLayoutKTiles_SW128<D_NOPE/64>{});\n        bf16* nope0_base = &nope0(group_idx, idx_in_group*8);\n        bf16* nope1_base = nope0_base + (plan.u.kv.dequant[1].nope.data() - plan.u.kv.dequant[0].nope.data());\n        e4m3* raw_nope0_base = plan.u.kv.raw_nope[rs.buf_idx].data() + group_idx*D_NOPE + idx_in_group*8;\n        e4m3* raw_nope1_base = raw_nope0_base + B_H*D_NOPE;\n\n        run_main_loop([&](const MainLoopArgs &args) {\n            // plan.bar_last_store_done.wait(args.bar_phase_batch_rel); // No need to wait since the raw nope producer must wait\n            plan.bar_q_utccp.wait(args.bar_phase_batch_rel);\n\n            CUTE_NO_UNROLL\n            for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);\n                plan.bar_raw_ready[rs.buf_idx].wait(rs.bar_phase);\n                plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1);\n                uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(rs.buf_idx == 0 ? nope0_base : nope1_base);\n                e4m3* raw_nope_base = rs.buf_idx == 0 ? raw_nope0_base : raw_nope1_base;\n                auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) {\n                    asm volatile (\"st.weak.shared::cta.b128 [%0], %1;\\n\" \n                        : \n                        : \"r\"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), \"q\"(data)   // 2 for sizeof(bf16)\n                    );  // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS\n                };\n                auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t {\n                    return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*D_NOPE + local_col_idx*(GROUP_SIZE*8));\n                };\n                // The following code suffers from a 2-way bank conflict when reading from SMEM.\n                if constexpr (MODEL_TYPE == ModelType::V32) {\n                    CUTE_UNROLL\n                    for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {\n                        int row_idx = local_row_idx*NUM_GROUPS + group_idx;\n                        bf16 scales[4];\n                        e8m0 scales_e8m0[4];\n                        *(uint32_t*)scales_e8m0 = *(uint32_t*)plan.scales[rs.index_buf_idx][row_idx];\n                        *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));\n                        *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));\n\n                        uint64_t 
cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);\n                        CUTE_UNROLL\n                        for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) { \n                            ku::nve4m3x2 data_fp8[4];\n                            ku::nvbf16x2 data_bf16[4]; \n                            *(uint64_t*)data_fp8 = cur_data_fp8x8;\n                            if (local_col_idx+1 < COLS_PER_GROUP)\n                                cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);\n                            bf16 scale = scales[local_col_idx / (D_NOPE/(GROUP_SIZE*8)/4)];\n                            CUTE_UNROLL\n                            for (int i = 0; i < 4; ++i) {\n                                data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));\n                            }\n                            st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);\n                        }\n                    }\n                } else {\n                    CUTE_UNROLL\n                    for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {\n                        int row_idx = local_row_idx*NUM_GROUPS + group_idx;\n                        bf16 scales[8];\n                        e8m0 scales_e8m0[8];\n                        *(uint64_t*)scales_e8m0 = *(uint64_t*)plan.scales[rs.index_buf_idx][row_idx];\n                        *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));\n                        *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));\n                        *(__nv_bfloat162_raw*)(scales+4) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+4));\n                        *(__nv_bfloat162_raw*)(scales+6) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+6));\n\n                        uint64_t cur_data_fp8x8 = 
get_raw_fp8(local_row_idx, 0);\n                        CUTE_UNROLL\n                        for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {\n                            ku::nve4m3x2 data_fp8[4];\n                            ku::nvbf16x2 data_bf16[4];\n                            *(uint64_t*)data_fp8 = cur_data_fp8x8;\n                            if (local_col_idx+1 < COLS_PER_GROUP)\n                                cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);\n                            bf16 scale = scales[local_col_idx];\n                            CUTE_UNROLL\n                            for (int i = 0; i < 4; ++i) {\n                                data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));\n                            }\n                            st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);\n                        }\n                    }\n                }\n                cutlass::arch::fence_view_async_shared();\n                plan.bar_nope_ready[rs.buf_idx].arrive();\n                plan.bar_raw_free[rs.buf_idx].arrive();\n                plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();\n                rs.update();\n            }\n        });\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100 ~ sm119\");\n    }\n#endif\n}\n\ntemplate<typename Kernel, typename TmaParams>\n__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1)\nflash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) {\n    Kernel::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(params, tma_params);\n}\n\ntemplate<ModelType MODEL_TYPE>\nvoid KernelTemplate<MODEL_TYPE>::run(const SparseAttnDecodeParams &params) {\n    KU_ASSERT(params.topk % B_TOPK == 0, \"topk (%d) mod B_TOPK (%d) must be 0\", params.topk, B_TOPK);\n    
KU_ASSERT(params.extra_topk % B_TOPK == 0, \"extra_topk (%d) mod B_TOPK (%d) must be 0\", params.extra_topk, B_TOPK);\n    KU_ASSERT(params.h_q == B_H);\n    KU_ASSERT(params.h_kv == 1);\n    KU_ASSERT(params.d_qk == D_Q);\n    KU_ASSERT(params.d_v == D_V);\n    if constexpr (MODEL_TYPE == ModelType::MODEL1) {\n        constexpr int BYTES_PER_TOKEN = D_NOPE + 2*D_ROPE + 8;\n        KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, \"Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1\");  // Each block must be contiguous\n    }\n\n    auto shape_Q_SW128 = make_shape(B_H, D_Q, params.s_q, params.b);\n    auto tma_Q_SW128 = cute::make_tma_copy(\n        SM90_TMA_LOAD{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.q),\n            make_layout(\n                shape_Q_SW128,\n                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b)\n            )\n        ),\n        SmemLayoutQ_SW128{}\n    );\n\n    auto shape_O = make_shape(B_H, D_V, params.s_q, params.b);\n    auto tma_O = cute::make_tma_copy(\n        SM90_TMA_STORE{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.out),\n            make_layout(\n                shape_O,\n                make_stride(params.stride_o_h_q, _1{}, params.stride_o_s_q, params.stride_o_b)\n            )\n        ),\n        SmemLayoutOBuf_TMA{}\n    );\n\n    CUtensorMap tensor_map_q_sw64{};\n    if constexpr (D_Q_SW64 > 0) {\n        tensor_map_q_sw64 = ku::make_tensor_map(\n            {D_Q_SW64, (uint64_t)params.h_q, D_Q_SW64/32, (uint64_t)params.s_q, (uint64_t)params.b},\n            ku::make_stride_helper(std::vector<int64_t>{params.stride_q_h_q, (int64_t)32, params.stride_q_s_q, params.stride_q_b}, sizeof(bf16)),\n            {32, B_H, D_Q_SW64/32, 1, 1},\n            (bf16*)params.q + D_Q_SW128,\n            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B,\n            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B\n        );\n    }\n\n    auto get_nope_rope_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t k_batch_stride) -> std::pair<CUtensorMap, CUtensorMap> {\n        static_assert(D_NOPE%8 == 0);\n        KU_ASSERT((int64_t)k_ptr % 16 == 0, \"The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f\", is_extra?\"extra_\":\"\", k_ptr);\n        KU_ASSERT(k_batch_stride % TMA_K_STRIDE == 0, \"%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary\", is_extra?\"extra_\":\"\", k_batch_stride, TMA_K_STRIDE);\n        CUtensorMap tensor_map_kv_nope = ku::make_tensor_map(\n            {D_NOPE/8, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},\n            {TMA_K_STRIDE},\n            {D_NOPE/8, 1},\n            k_ptr,\n            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT64,\n            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,\n            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B\n        );  // NOTE We combine 8 float8 into 1 int64 since boxdim cannot > 256\n        CUtensorMap tensor_map_kv_rope = ku::make_tensor_map(\n            {D_ROPE, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},\n            {TMA_K_STRIDE},\n            {K_ROPE_SW/2, 1},\n            (uint8_t*)k_ptr + (MODEL_TYPE == ModelType::V32 ? (D_NOPE+16) : D_NOPE),\n            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            K_ROPE_SW == 64 ? 
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,\n            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B\n        );\n        return {tensor_map_kv_nope, tensor_map_kv_rope};\n    };\n\n    auto [tensor_map_kv_nope, tensor_map_kv_rope] = get_nope_rope_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block);\n    CUtensorMap tensor_map_extra_kv_nope{}, tensor_map_extra_kv_rope{};\n    if (params.extra_topk > 0) {\n        std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_nope_rope_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block);\n    }\n\n    TmaParams<\n        decltype(shape_Q_SW128), decltype(tma_Q_SW128),\n        decltype(shape_O), decltype(tma_O)\n    > tma_params = {\n        shape_Q_SW128, tma_Q_SW128,\n        shape_O, tma_O,\n        tensor_map_q_sw64,\n        tensor_map_kv_nope,\n        tensor_map_kv_rope,\n        tensor_map_extra_kv_nope,\n        tensor_map_extra_kv_rope\n    };\n    auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE>, decltype(tma_params)>;\n\n    constexpr size_t smem_size = sizeof(SharedMemoryPlan);\n    static_assert(smem_size < 227*1024);\n    KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    \n    // NOTE Don't use PDL because of potential compiler bugs!\n    mla_kernel<<<dim3(params.s_q, params.num_sm_parts, 1), dim3(NUM_THREADS, 1, 1), smem_size, params.stream>>>(params, tma_params);\n    KU_CHECK_KERNEL_LAUNCH();\n}\n\ntemplate<ModelType MODEL_TYPE>\nvoid run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params) {\n    KernelTemplate<MODEL_TYPE>::run(params);\n}\n\n}\n"
  },
  {
    "path": "csrc/sm100/decode/head64/kernel.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::decode::head64 {\n\ntemplate<ModelType MODEL_TYPE>\nvoid run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);\n\n}\n\n"
  },
  {
    "path": "csrc/sm100/helpers.h",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <cuda_bf16.h>\n#include <cuda_fp8.h>\n\n#include \"defines.h\"\n\nnamespace sm100 {\n\nusing namespace cute;\n\nCUTE_DEVICE\nint int4_max(int4 t) {\n    return max(max(t.x, t.y), max(t.z, t.w));\n}\n\nCUTE_DEVICE\nint int4_min(int4 t) {\n    return min(min(t.x, t.y), min(t.z, t.w));\n}\n\n// Convert 2x fp8_e4m3 to 2x bf16 with scaling\nCUTE_DEVICE\nnv_bfloat162 fp8x2_to_bf16x2_with_scale(__nv_fp8x2_e4m3 data, nv_bfloat16 scale) {\n    // TODO Use native conversion for CUDA >= 13.1\n    float2 data_float2 = (float2)data;\n    nv_bfloat162 data_bf16x2 = __float22bfloat162_rn(data_float2);\n    return nv_bfloat162 {\n        data_bf16x2.x * scale,\n        data_bf16x2.y * scale\n    };\n}\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/collective/fmha_common.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cutlass/kernel_hardware_info.h\"\n#include \"cutlass/arch/reg_reconfig.h\"\n#include \"cute/tensor.hpp\"\n\nnamespace cutlass::fmha::collective {\n\nusing namespace cute;\n\ntemplate<typename Atom, typename TA, typename TB, typename TC>\nCUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {\n  constexpr int rA = decltype(rank(tA))::value;\n  constexpr int rB = decltype(rank(tB))::value;\n  constexpr int rC = decltype(rank(tC))::value;\n  static_assert(rA == 3 && rB == 3 && rC == 3);\n\n  CUTLASS_PRAGMA_UNROLL\n  for (int k_block = 0; k_block < size<2>(tA); k_block++) {\n    cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);\n    atom.accumulate_ = decltype(atom.accumulate_)::One;\n  }\n}\n\ntemplate<typename Atom, typename TA, typename TB, typename TC>\nCUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {\n  atom.accumulate_ = decltype(atom.accumulate_)::Zero;\n  gemm_reset_zero_acc(atom, tA, tB, tC);\n}\n\ntemplate<class Layout, class Stages = _1>\nCUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {\n    return composition(layout, prepend<decltype(rank(layout))::value>(make_layout(stages), _));\n}\n\ntemplate<class T>\nCUTE_DEVICE T warp_uniform(T a) {\n  return __shfl_sync(0xffffffff, a, 
0);\n}\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>\nCUTE_HOST_DEVICE constexpr\nauto\nto_tiled_mma_sm100_ts(\n    TiledMMA<MMA_Atom<\n      MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,\n                    cute::C<M>, cute::C<N>,\n                    cute::integral_constant<UMMA::Major, a_major>,\n                    cute::integral_constant<UMMA::Major, b_major>,\n                    cute::integral_constant<UMMA::ScaleIn, a_neg>,\n                    cute::integral_constant<UMMA::ScaleIn, b_neg>>,\n      TAs...>, TMs...>) {\n\n  return TiledMMA<MMA_Atom<\n    MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,\n                                M, N,\n                                a_major, b_major,\n                                a_neg, b_neg, UMMA::Saturate::False>>,\n    TAs...>, TMs...>{};\n}\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... 
TMs>\nCUTE_HOST_DEVICE constexpr\nauto\nto_tiled_mma_sm100_ts(\n    TiledMMA<MMA_Atom<\n      SM100_MMA_F16BF16_SS<a_type, b_type, c_type,\n                    M, N,\n                    a_major,\n                    b_major,\n                    a_neg,\n                    b_neg>,\n      TAs...>, TMs...>) {\n  return TiledMMA<MMA_Atom<\n    SM100_MMA_F16BF16_TS<a_type, b_type, c_type,\n                                M, N,\n                                a_major, b_major,\n                                a_neg, b_neg, UMMA::Saturate::False>,\n    TAs...>, TMs...>{};\n}\n\ntemplate<uint32_t RegCount>\nCUTLASS_DEVICE\nvoid warpgroup_reg_set() {\n  if constexpr (RegCount < 128) {\n    cutlass::arch::warpgroup_reg_dealloc<RegCount>();\n  }\n  else {\n    cutlass::arch::warpgroup_reg_alloc<RegCount>();\n  }\n}\n\n}  // namespace cutlass::fmha::collective\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/collective/fmha_fusion.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n\n#include \"cutlass/cutlass.h\"\n#include \"cute/tensor.hpp\"\n\nnamespace cutlass::fmha::collective {\n\nusing namespace cute;\n\nstruct NoMask {\n  template<class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int get_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    return ceil_div(get<1>(problem_size), get<1>(tile_shape));\n  }\n\n  template<class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int get_masked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    return 0;\n  }\n\n  template<class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int get_unmasked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    return get_trip_count(blk_coord, tile_shape, problem_size);\n  }\n\n  template<class AccQK, class IndexQK, class ProblemSize>\n  CUTLASS_DEVICE\n  void apply_mask(\n      AccQK& acc_qk,\n      IndexQK const& index_qk,\n      ProblemSize const& problem_size) {\n\n    return;\n  }\n};\n\nstruct ResidualMask : NoMask {\n\n  using Base = NoMask;\n\n  template <class BlkCoord, class TileShape, class ProblemSize>\n  
CUTLASS_DEVICE int get_masked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {\n      return 1;\n    }\n    return 0;\n  }\n\n  template<class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int get_unmasked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    // if the sequence length does not divide the tile size evenly\n    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {\n      return get_trip_count(blk_coord, tile_shape, problem_size) - 1;\n    }\n    return get_trip_count(blk_coord, tile_shape, problem_size);\n  }\n\n  template<class AccQK, class IndexQK, class ProblemSize>\n  CUTLASS_DEVICE\n  void apply_mask(\n      AccQK& acc_qk,\n      IndexQK const& index_qk,\n      ProblemSize const& problem_size) {\n\n    // This is useful if seqlen_k % kBlockN != 0 since it masks\n    // the remaining elements out from softmax.\n    // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar\n    // issues as they are transparently taken care of by TMA and the\n    // epilogue, if it is instantiated with predication support.\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(acc_qk); i++) {\n      auto pos = index_qk(i);\n      if (get<1>(pos) >= get<1>(problem_size)) {\n        acc_qk(i) = -INFINITY;\n      }\n    }\n  }\n};\n\nstruct ResidualMaskForBackward : NoMask {\n\n  using Base = NoMask;\n\n  template <class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE int get_masked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {\n      return 1;\n    }\n    return 0;\n  }\n\n  template<class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int 
get_unmasked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    // if the sequence length does not divide the tile size evenly\n    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {\n      return get_trip_count(blk_coord, tile_shape, problem_size) - 1;\n    }\n    return get_trip_count(blk_coord, tile_shape, problem_size);\n  }\n\n  template<class AccQK, class IndexQK, class ProblemSize>\n  CUTLASS_DEVICE\n  void apply_mask(\n      AccQK& acc_qk,\n      IndexQK const& index_qk,\n      ProblemSize const& problem_size) {\n\n    // This is useful if seqlen_k % kBlockN != 0 since it masks\n    // the remaining elements out from softmax.\n    // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar\n    // issues as they are transparently taken care of by TMA and the\n    // epilogue, if it is instantiated with predication support.\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(acc_qk); i++) {\n      auto pos = index_qk(i);\n      if (! 
elem_less(pos, select<0,1>(problem_size))) {\n        acc_qk(i) = -INFINITY;\n      }\n    }\n  }\n};\n\n// There are two ways to do causal if N_Q != N_K\n// (1) The Q is at the beginning of the matrix\n// (2) The Q is at the end of the matrix\ntemplate<bool kIsQBegin = true>\nstruct CausalMask : NoMask {\n\n  using Base = NoMask;\n\n  static constexpr bool IsQBegin = kIsQBegin;\n\n  template<class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int get_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    // See note below on different ways to think about causal attention\n    // Again, we'd add the offset_q into the max_blocks_q calculation\n    int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);\n    if constexpr (IsQBegin) {\n      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));\n      return std::min(max_blocks_k, max_blocks_q);\n    } else {\n      const int offset_q = get<1>(problem_size) - get<0>(problem_size);\n      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));\n      return std::min(max_blocks_k, max_blocks_q);\n    }\n  }\n\n  template<class BlkCoord, class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int get_masked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);\n    if constexpr (IsQBegin) {\n      return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));\n    } else {\n      const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;\n      return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);\n    }\n  }\n\n  template<class BlkCoord, 
class TileShape, class ProblemSize>\n  CUTLASS_DEVICE\n  int get_unmasked_trip_count(\n      BlkCoord const& blk_coord,\n      TileShape const& tile_shape,\n      ProblemSize const& problem_size) {\n\n    return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);\n  }\n\n  template<class AccQK, class IndexQK, class ProblemSize>\n  CUTLASS_DEVICE\n  void apply_mask(\n      AccQK& acc_qk,\n      IndexQK const& index_qk,\n      ProblemSize const& problem_size) {\n\n    // There are two ways to do causal if N_Q != N_K\n    // (1) is to assume that the Q is at the beginning of the matrix\n    //    - this is the default setting.\n    // (2) is that it is at the end of the matrix\n    //    - this is usually what we want for inference settings\n    //      where we only compute the next row and use cache for the rest\n    //    - if you'd like this, you only need to set kIsQBegin=false\n\n    if constexpr (IsQBegin) {\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size(acc_qk); i++) {\n        auto pos = index_qk(i);\n        if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {\n          acc_qk(i) = -INFINITY;\n        }\n      }\n    } else {\n      const auto offset_q = get<1>(problem_size) - get<0>(problem_size);\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size(acc_qk); i++) {\n        auto pos = index_qk(i);\n        if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {\n          acc_qk(i) = -INFINITY;\n        }\n      }\n    }\n  }\n};\n\ntemplate<bool kIsQBegin = true>\nstruct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {\n\n  using Base = CausalMask<kIsQBegin>;\n\n  template<class AccQK, class IndexQK, class ProblemSize>\n  CUTLASS_DEVICE\n  void apply_mask(\n      AccQK& acc_qk,\n      IndexQK const& index_qk,\n      ProblemSize const& problem_size) {\n\n    // There are two ways to do causal if 
N_Q != N_K\n    // (1) is to assume that the Q is at the beginning of the matrix\n    //    - this is what we demonstrate here\n    // (2) is that it is at the end of the matrix\n    //    - this is usually what we want for inference settings\n    //      where we only compute the next row and use cache for the rest\n    //    - if you'd like this, you only need to add an offset like so:\n    //      get<0>(pos) + offset_q < get<1>(pos)\n    int offset_q = 0;\n    if constexpr (!kIsQBegin) {\n      offset_q = get<1>(problem_size) - get<0>(problem_size);\n    }\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(acc_qk); i++) {\n      auto pos = index_qk(i);\n      bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);\n      if (masked) {\n        acc_qk(i) = -INFINITY;\n      }\n    }\n  }\n\n};\n\nstruct VariableLength {\n  int max_length;\n  int* cumulative_length = nullptr;\n  int total_length = -1;\n\n  CUTE_HOST_DEVICE operator int() const {\n    return max_length;\n  }\n};\n\ntemplate<class T> struct is_variable_length_impl : std::false_type {};\ntemplate<> struct is_variable_length_impl<VariableLength> : std::true_type {};\ntemplate<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;\n\ntemplate<class Shape, class Idx>\nCUTE_HOST_DEVICE\nconstexpr auto\napply_variable_length(Shape const& shape, Idx const& idx) {\n  return transform_leaf(shape, [&](auto const& s) {\n    if constexpr (is_variable_length_v<decltype(s)>) {\n      return s.cumulative_length[idx+1] - s.cumulative_length[idx];\n    }\n    else {\n      return s;\n    }\n  });\n}\n\ntemplate<class Shape, class Coord, class Idx>\nCUTE_HOST_DEVICE\nconstexpr auto\napply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {\n  auto new_shape = apply_variable_length(shape, idx);\n  auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {\n    if constexpr 
(is_variable_length_v<decltype(s)>) {\n      return cute::make_tuple(c, s.cumulative_length[idx]);\n    }\n    else {\n      return c;\n    }\n  });\n  return cute::make_tuple(new_shape, new_coord);\n}\n\ntemplate<class Shape, class Coord>\nCUTE_HOST_DEVICE\nconstexpr auto\napply_variable_length_offset(Shape const& shape, Coord const& coord) {\n  auto idx = back(back(coord));\n  auto result_shape = transform_leaf(shape, [&](auto const& s) {\n    if constexpr (is_variable_length_v<decltype(s)>) {\n      return s.cumulative_length[idx+1] - s.cumulative_length[idx];\n    }\n    else {\n      return s;\n    }\n  });\n  auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {\n    if constexpr (is_variable_length_v<decltype(s)>) {\n      return s.cumulative_length[idx];\n    }\n    else {\n      return _0{};\n    }\n  });\n  return cute::make_tuple(result_shape, result_offset);\n}\n\n}  // namespace cutlass::fmha::collective\n\nnamespace cute {\n\ntemplate<>\nstruct is_integral<cutlass::fmha::collective::VariableLength> : true_type {};\n\nCUTE_HOST_DEVICE\nvoid print(cutlass::fmha::collective::VariableLength a) {\n  printf(\"Varlen<%d, %p>\", a.max_length, a.cumulative_length);\n}\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cute/layout.hpp\"\n#include \"cutlass/epilogue/collective/collective_builder.hpp\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n\nnamespace cutlass::fmha::collective {\n\ntemplate<\n  class Element,\n  class ElementAcc,\n  class TileShape,  // Q, D, _\n  class StrideO,    // Q, D, B\n  class StrideLSE_,   // Q, B\n  class OrderLoadEpilogue = cute::false_type\n>\nstruct Sm100FmhaFwdEpilogueTmaWarpspecialized {\n    \n  using Pipeline = cutlass::PipelineAsync<2>;\n\n//  using SmemLayoutO = decltype(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));\n  using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<\n        cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>());\n//  using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));\n  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));\n  using SmemLayoutO_ = SmemLayoutO;\n  using StrideLSE = StrideLSE_;\n  using ElementOut = Element;\n\n  static const int NumWarpsEpilogue = 1;\n  static const int NumWarpsLoad = 1;\n\n  struct TensorStorage {\n\n    using SmemLayoutO = SmemLayoutO_;\n    cute::array_aligned<Element, 
cute::cosize_v<SmemLayoutO>> smem_o;\n\n  };\n\n  struct Arguments {\n    Element* ptr_O;\n    StrideO dO;\n\n    ElementAcc* ptr_LSE;\n    StrideLSE dLSE;\n  };\n\n  using TMA_O = decltype(make_tma_copy(\n    SM90_TMA_STORE{},\n    make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}),\n    SmemLayoutO{}(_,_,_0{})\n  ));\n\n\n  struct Params {\n    TMA_O tma_store_o;\n\n    ElementAcc* ptr_LSE;\n    StrideLSE dLSE;\n  };\n\n  // FMHA and MLA have different input ProblemShapes; \n  // get problem_shape_O according to the input ProblemShape.\n  template<class ProblemShape>\n  CUTLASS_DEVICE static constexpr\n  auto get_problem_shape_O (\n    ProblemShape const& problem_shape) {\n    if constexpr (rank_v<decltype(get<2>(ProblemShape{}))> == 2) {\n      return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));\n    } else {\n      return select<0,2,3>(problem_shape);\n    }\n  }\n\n  template<class ProblemShape>\n  static Params to_underlying_arguments(\n      ProblemShape const& problem_shape,\n      Arguments const& args,\n      void* workspace = nullptr) {\n\n    auto ptr_O = args.ptr_O;\n    StrideO dO = args.dO;\n\n    auto problem_shape_O = get_problem_shape_O(problem_shape);\n\n    if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {\n      auto cumulative_length_q = get<0>(problem_shape).cumulative_length;\n      if (cumulative_length_q != nullptr) {\n          int max_length_q = get<0>(problem_shape).max_length;\n          get<0>(problem_shape_O).max_length = max(1, max_length_q);\n          // for variable sequence length, the batch is in units of row_stride\n          get<2,1>(dO) = get<0>(dO);\n          get<2,1>(problem_shape_O) = max(1, max_length_q * (1 + get<2,1>(problem_shape_O)));\n          // offset ptr by the amount we add back in later\n          ptr_O -= max_length_q * get<0>(dO);\n      }\n    } else {\n      get<0>(problem_shape_O) = max(1, get<0>(problem_shape_O));\n    }\n\n    auto 
tma_store_o = make_tma_copy(\n      SM90_TMA_STORE{},\n      make_tensor(ptr_O, problem_shape_O, dO),\n      SmemLayoutO{}(_,_,_0{})\n    );\n\n    return {\n      tma_store_o,\n      args.ptr_LSE,\n      args.dLSE\n    };\n  }\n\n  CUTLASS_DEVICE\n  static void prefetch_tma_descriptors(Params const& params) {\n    cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());\n  }\n\n  const Params& params;\n\n  CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}\n\n  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>\n  CUTLASS_DEVICE auto\n  store(\n      BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,\n      Params const& params, ParamsProblemShape const& params_problem_shape,\n      TensorStorage& shared_storage,\n      Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {\n\n    BlkCoord blk_coord = blk_coord_in;\n    uint32_t lane_predicate = cute::elect_one_sync();\n\n    using X = Underscore;\n\n    int o0_index = 2 * get<0>(blk_coord);\n    int o1_index = 2 * get<0>(blk_coord) + 1;\n\n    Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));\n    // offset mode 0 by (max_length - real_length)\n    // offset mode 3,1 by cumulative_length + real_length\n    // the ptr is already offset by - max_length\n    // so in total this achieves \n    int offs_0 = 0;\n    int offs_2_1 = 0;\n\n    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n      auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;\n      if (cumulative_length_q != nullptr) {\n        int max_length_q = get<0>(params_problem_shape).max_length;\n        offs_0 = max_length_q - get<0>(problem_shape);\n        offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape);\n        get<2,1>(blk_coord) = 0;\n      }\n    }\n\n    Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, 
make_coord(_0{}, offs_2_1)), mO_qdl_p);\n\n    Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{});\n    Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord));\n    Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});\n    auto block_tma = params.tma_store_o.get_slice(0);\n    Tensor tOsO = block_tma.partition_S(sO);\n    Tensor tOgO = block_tma.partition_D(gO);\n\n    auto pipeline_release_state = pipeline_consumer_state;\n\n    // O1 O2\n    // one pipeline: O\n    // wait from corr, issue tma store on smem\n    pipeline.consumer_wait(pipeline_consumer_state);\n    ++pipeline_consumer_state;\n\n    if (lane_predicate) {\n      copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index));\n    }\n    tma_store_arrive();\n\n    pipeline.consumer_wait(pipeline_consumer_state);\n    ++pipeline_consumer_state;\n\n    if (lane_predicate) {\n      copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index));\n    }\n    tma_store_arrive();\n\n    tma_store_wait<1>();\n\n    pipeline.consumer_release(pipeline_release_state);\n    ++pipeline_release_state;\n\n    tma_store_wait<0>();\n\n    if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {\n      cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, \n                                          cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n    }\n\n    pipeline.consumer_release(pipeline_release_state);\n    ++pipeline_release_state;\n\n  } \n\n};\n\n}  // namespace cutlass::fmha::collective\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/arch/memory_sm80.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n#include \"cute/arch/simd_sm100.hpp\"\n#include \"cute/tensor.hpp\"\n#include \"cute/layout.hpp\"\n\n#include \"../collective/fmha_common.hpp\"\n#include \"../collective/fmha_fusion.hpp\"\n#include \"../collective/sm100_fmha_load_tma_warpspecialized.hpp\"\n\nnamespace cutlass::fmha::collective {\n\nusing namespace cute;\n\ntemplate<\n  class Element_,\n  class ElementQK_,\n  class ElementPV_,\n  class TileShape_,\n  class StrideQ_,\n  class StrideK_,\n  class StrideV_,\n  class Mask_,\n  // shape here is QG K H\n  // and refers to the two softmax warps\n  // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)\n  // (1, 2, 1) means they sit side by side (best for small Q / large K)\n  class ThreadShape = Shape<_2, _1, _1>,\n  // Since shared memory is sufficient for FMHA, there is no need to reuse shared memory.\n  class OrderLoadEpilogue = cute::false_type\n>\nstruct Sm100FmhaFwdMainloopTmaWarpspecialized {\n\n  using Element = Element_;\n  using ElementQK = ElementQK_;\n  using ElementPV = ElementPV_;\n  using TileShape = TileShape_;\n  using StrideQ = StrideQ_;\n  using StrideK = StrideK_;\n  using StrideV = StrideV_;\n  using Mask = Mask_;\n\n  
static constexpr int StageCountQ = 2;\n  static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3;\n\n  using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;\n  using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;\n  \n  using ClusterShape = Shape<_1, _1, _1>;\n\n  static const int Alignment = 128 / sizeof_bits_v<Element>;\n\n  using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{}));\n\n  using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));\n\n  using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      Element, StrideQ, Alignment,\n      Element, StrideK, Alignment,\n      ElementQK,\n      TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,\n      cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;\n\n  using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // the stride for A does not matter since we do not load from smem at all\n      Element, StrideK, Alignment,\n      Element, decltype(select<1,0,2>(StrideV{})), Alignment,\n      ElementPV,\n      TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,\n      cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;\n\n  using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StageCountQ>{}));\n  using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountKV>{}));\n  using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountKV>{}));\n\n  // Reuse shared memory for V and O.\n  static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;\n  struct TensorStorage {\n    cute::array_aligned<Element, 
cute::cosize_v<SmemLayoutQ>> smem_q;\n    union {\n      cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;\n      cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n    };\n  };\n\n  enum class TmemAllocation : uint32_t {\n    kSizeS = 128,\n    kSizeO = 128,\n    kSizeP = 32,\n    S0 = 0,\n    S1 = S0 + kSizeS,\n    V0 = S0,  // stats storage from softmax to correction\n    V1 = S1,\n    P0 = S0 + kSizeP,\n    P1 = S1 + kSizeP,\n    O0 = S1 + kSizeS,\n    O1 = O0 + kSizeO,\n    kEnd = O1 + kSizeO\n  };\n\n  // indices for V0 / V1\n  enum : int {\n    kIdxOldRowMax = 0,\n    kIdxNewRowMax = 1,\n    kIdxFinalRowSum = 0,\n    kIdxFinalRowMax = 1\n  };\n\n  // from load to mma warp, protects q in smem\n  using PipelineQ = cutlass::PipelineTmaUmmaAsync<\n    StageCountQ,\n    typename CollectiveMmaQK::AtomThrShapeMNK\n  >;\n\n  // from load to mma warp, protects k/v in smem\n  using PipelineKV = cutlass::PipelineTmaUmmaAsync<\n    StageCountKV,\n    typename CollectiveMmaQK::AtomThrShapeMNK\n  >;\n\n  // from mma to softmax0/1 warp, protects S in tmem\n  // (not sure yet about the reverse direction)\n  // there is one pipe per softmax warp, and the mma warp alternates between them\n  using PipelineS = cutlass::PipelineUmmaAsync<1>;\n\n  // from softmax0/1 to correction wg\n  using PipelineC = cutlass::PipelineAsync<1>;\n\n  // from mma to correction\n  using PipelineO = cutlass::PipelineUmmaAsync<2>;\n\n  // from corr to epilogue\n  using PipelineE = cutlass::PipelineAsync<2>;\n\n  using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier<\n    /*stages*/ 1, /*groups*/ 2>;\n\n  static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);\n\n  static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);\n  static const int TransactionBytesLoadV = 
cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);\n\n  static_assert(TransactionBytesLoadK == TransactionBytesLoadV, \"K and V smem layouts must be of equal size\");\n\n  using Load = Sm100FmhaLoadTmaWarpspecialized<\n    Element, StrideQ, StrideK, StrideV,\n    CollectiveMmaQK, CollectiveMmaPV,\n    SmemLayoutQ, SmemLayoutK, SmemLayoutV,\n    TensorStorage, PipelineQ, PipelineKV, Mask, TileShape\n  >;\n\n  struct Arguments {\n    typename Load::Arguments load;\n\n    // if zero, defaults to 1/sqrt(D)\n    float scale_softmax = 0.0f;\n\n    // scaling factors to dequantize QKV\n    float scale_q = 1.0f;\n    float scale_k = 1.0f;\n    float scale_v = 1.0f;\n\n    // scaling factor to quantize O\n    float inv_scale_o = 1.0f;\n  };\n\n  struct Params {\n    typename Load::Params load;\n\n    float scale_softmax;\n    float scale_softmax_log2;\n\n    float scale_output;\n  };\n\n  template<class ProblemShape>\n  static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) {\n    return true;\n  }\n\n  template<class ProblemShape>\n  static Params to_underlying_arguments(\n      ProblemShape const& problem_shape,\n      Arguments const& args,\n      void* workspace) {\n\n    float scale_softmax = args.scale_softmax;\n    if (scale_softmax == 0.0f) {\n      scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape));\n    }\n    float log2_e = static_cast<float>(std::log2(std::exp(1.0)));\n\n    return Params{\n        Load::to_underlying_arguments(problem_shape, args.load, workspace),\n        args.scale_q * args.scale_k * scale_softmax,\n        args.scale_q * args.scale_k * log2_e * scale_softmax,\n        args.scale_v * args.inv_scale_o\n    };\n  }\n\n  CUTLASS_DEVICE\n  static void prefetch_tma_descriptors(Params const& params) {\n      Load::prefetch_tma_descriptors(params.load);\n  }\n\n  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>\n  CUTLASS_DEVICE void\n  load(\n      
BlkCoord const& blk_coord, ProblemShape const& problem_shape,\n      Params const& params, ParamsProblemShape const& params_problem_shape,\n      TensorStorage& storage,\n      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,\n      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {\n\n    Load load;\n    load.load(blk_coord, problem_shape, params.load, params_problem_shape,\n        storage,\n        pipeline_q, pipeline_q_producer_state,\n        pipeline_kv, pipeline_kv_producer_state);\n  }\n\n  template<class BlkCoord, class ProblemShape>\n  CUTLASS_DEVICE auto\n  mma(\n      BlkCoord const& blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      TensorStorage& storage,\n      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state,\n      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state,\n      PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state,\n      PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state,\n      PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) {\n\n    auto pipeline_q_release_state = pipeline_q_consumer_state;\n    auto pipeline_kv_release_state = pipeline_kv_consumer_state;\n\n    int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);\n\n    typename CollectiveMmaQK::TiledMma mma_qk;\n    ThrMMA thr_mma_qk = mma_qk.get_slice(0);\n\n    typename CollectiveMmaPV::TiledMma mma_pv;\n    TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv);\n    ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0);\n\n    Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});\n    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});\n    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});\n\n    
Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ);\n    Tensor tSrK = thr_mma_qk.make_fragment_B(sK);\n    Tensor tOrV = thr_mma_pv.make_fragment_B(sV);\n\n    // tmem layout is\n    // S0 S1 O0 O1\n    // sequential in memory, where S overlaps with P and V\n\n    Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{}));\n    Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{}));\n\n    Tensor tStS0 = tStS;\n    tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0);\n    Tensor tStS1 = tStS;\n    tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1);\n\n    Tensor tOtO0 = tOtO;\n    tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0);\n    Tensor tOtO1 = tOtO;\n    tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1);\n\n    Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{});\n    Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{});  // slice out staging\n\n    Tensor tOrP0 = tOrP;\n    tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);\n    Tensor tOrP1 = tOrP;\n    tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1);\n\n    int k_index = 0;\n    int v_index = 0;\n    int q_index = 0;\n\n    // wait for Q1\n    q_index = pipeline_q_consumer_state.index();\n    pipeline_q.consumer_wait(pipeline_q_consumer_state);\n    ++pipeline_q_consumer_state;\n\n    Tensor tSrQ0 = tSrQ(_,_,_,q_index);\n\n\n    // wait for K1\n    k_index = pipeline_kv_consumer_state.index();\n    pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n    ++pipeline_kv_consumer_state;\n\n    // gemm Q1 * K1 -> S1\n    pipeline_s0.producer_acquire(pipeline_s0_producer_state);\n\n    gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0);\n\n    pipeline_s0.producer_commit(pipeline_s0_producer_state);\n    ++pipeline_s0_producer_state;\n\n    // release K1\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      
pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n    }\n\n    // wait for Q2\n    if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) {\n      q_index = pipeline_q_consumer_state.index();\n      pipeline_q.consumer_wait(pipeline_q_consumer_state);\n      ++pipeline_q_consumer_state;\n    }\n\n    Tensor tSrQ1 = tSrQ(_,_,_,q_index);\n\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      k_index = pipeline_kv_consumer_state.index();\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n    }\n\n    pipeline_s1.producer_acquire(pipeline_s1_producer_state);\n\n    // gemm Q2 * K1 -> S2\n    gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1);\n\n    pipeline_s1.producer_commit(pipeline_s1_producer_state);\n    ++pipeline_s1_producer_state;\n\n    // release K1\n    pipeline_kv.consumer_release(pipeline_kv_release_state);\n    ++pipeline_kv_release_state;\n\n    // wait for V1\n    v_index = pipeline_kv_consumer_state.index();\n    pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n    ++pipeline_kv_consumer_state;\n\n    // this acquire returns the ownership of all of S0 to the mma warp\n    // including the P0 part\n    // acquire corr first to take it out of the critical\n    // path since softmax takes longer\n    pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n    pipeline_s0.producer_acquire(pipeline_s0_producer_state);\n\n    // gemm P1 * V1 -> O1\n    gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0);\n\n    pipeline_corr.producer_commit(pipeline_corr_producer_state);\n    ++pipeline_corr_producer_state;\n\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n    }\n\n    mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero;\n\n    // loop:\n    mask_tile_count -= 1;\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n\n      
// wait for Ki\n      k_index = (pipeline_kv_consumer_state.index());\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n\n      // gemm Q1 * Ki -> S1\n      gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0);\n\n      pipeline_s0.producer_commit(pipeline_s0_producer_state);\n      ++pipeline_s0_producer_state;\n\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        pipeline_kv.consumer_release(pipeline_kv_release_state);\n        ++pipeline_kv_release_state;\n      }\n\n      // gemm P2 * V(i-1) -> O2\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        v_index = pipeline_kv_consumer_state.index();\n        pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n        ++pipeline_kv_consumer_state;\n      }\n\n      pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n      pipeline_s1.producer_acquire(pipeline_s1_producer_state);\n\n      gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1);\n\n      pipeline_corr.producer_commit(pipeline_corr_producer_state);\n      ++pipeline_corr_producer_state;\n\n      // release V(i-1)\n      pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        k_index = (pipeline_kv_consumer_state.index());\n        pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n        ++pipeline_kv_consumer_state;\n      }\n\n      // gemm Q2 * Ki -> S2\n      gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1);\n\n      pipeline_s1.producer_commit(pipeline_s1_producer_state);\n      ++pipeline_s1_producer_state;\n\n      // release Ki\n      pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n\n      // wait for Vi\n      v_index = (pipeline_kv_consumer_state.index());\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n\n      // gemm P1 * Vi -> O1\n      
pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n\n      pipeline_s0.producer_acquire(pipeline_s0_producer_state);\n\n      gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0);\n\n      pipeline_corr.producer_commit(pipeline_corr_producer_state);\n      ++pipeline_corr_producer_state;\n\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        pipeline_kv.consumer_release(pipeline_kv_release_state);\n        ++pipeline_kv_release_state;\n      }\n    }\n\n    // release Q1\n    pipeline_q.consumer_release(pipeline_q_release_state);\n    ++pipeline_q_release_state;\n\n    // release Q2\n    if constexpr (get<0>(ThreadShape{}) > 1) {\n      pipeline_q.consumer_release(pipeline_q_release_state);\n      ++pipeline_q_release_state;\n    }\n\n    // wait for Vi\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      v_index = pipeline_kv_consumer_state.index();\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n    }\n\n    // gemm P2 * Vi -> O2\n    pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n    pipeline_s1.producer_acquire(pipeline_s1_producer_state);\n\n    gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1);\n\n    pipeline_corr.producer_commit(pipeline_corr_producer_state);\n    ++pipeline_corr_producer_state;\n\n    // release Vi\n    pipeline_kv.consumer_release(pipeline_kv_release_state);\n    ++pipeline_kv_release_state;\n\n    pipeline_s0.producer_commit(pipeline_s0_producer_state);\n    ++pipeline_s0_producer_state;\n\n    pipeline_s1.producer_commit(pipeline_s1_producer_state);\n    ++pipeline_s1_producer_state;\n\n    // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ...\n    // Q1 * K1  , Q2 * K1  , S11 * V1 , Q1 * K2  , S21 * V1  , Q2 * K2 , S12 * V2 , Q1 * K3  , S22 * V2 , ...\n  }\n\n  template<bool need_apply_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>\n  
CUTLASS_DEVICE auto\n  softmax_step(\n      float& row_max, float& row_sum,\n      Stage stage, bool final_call,\n      BlkCoord const& blk_coord, CoordTensor const& cS,\n      Params const& params, ProblemShape const& problem_shape,\n      PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,\n      PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,\n      OrderBarrierSoftmax& order_s) {\n\n    Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);\n\n    Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));\n    tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);\n\n    Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));\n    tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);\n    Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));\n\n    auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};\n    Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));\n    tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1));\n    Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));\n\n    // Each thread owns a single row\n    using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem\n    using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x;  // 4x32 threads with 128 cols of 8b elem\n    using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x;   // 4x32 threads with 2 cols of 32b elem\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);\n    auto thr_tmem_load   = tiled_tmem_load.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS);\n    Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS);\n\n    auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v);\n    auto thr_tmem_storev  = tiled_tmem_storev.get_slice(thread_idx);\n\n    Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v);\n    Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v);\n\n    auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P);\n    auto thr_tmem_store  = tiled_tmem_store.get_slice(thread_idx);\n\n    Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P);\n    tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get());\n    Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P);\n\n    // wait on tensor core pipe\n    pipeline_s.consumer_wait(pipeline_s_consumer_state);\n\n    // read all of S from tmem into reg mem\n    Tensor tTMEM_LOADrS = make_tensor<ElementQK>(shape(tTMEM_LOADcS));\n    copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS);\n\n    if constexpr (need_apply_mask) {\n      Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape);\n    }\n\n    ElementQK old_row_max = row_max;\n    {\n      // compute rowmax\n      float row_max_0 = row_max;\n      float row_max_1 = row_max;\n      float row_max_2 = row_max;\n      float row_max_3 = row_max;\n      
CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size(tTMEM_LOADrS); i += 4) {\n        row_max_0  = ::fmax(row_max_0, tTMEM_LOADrS(i));\n        row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1));\n        row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2));\n        row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3));\n      }\n      row_max = ::fmax(row_max_0, row_max_1);\n      row_max = ::fmax(row_max, row_max_2);\n      row_max = ::fmax(row_max, row_max_3);\n    }\n\n    ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max;\n\n    Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));\n    tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max;\n    tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe;\n    copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);\n\n    pipeline_c.producer_commit(pipeline_c_producer_state);\n    ++pipeline_c_producer_state;\n\n    // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's)\n\n    ElementQK scale = params.scale_softmax_log2;\n    ElementQK row_max_scale = row_max_safe * scale;\n\n    float2 scale_fp32x2 = make_float2(scale, scale);\n    float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale);\n\n    Tensor tTMEM_STORErS_x4 = make_tensor<uint32_t>(shape(tTMEM_STOREcS));\n\n    constexpr int kConversionsPerStep = 2;\n\n    Tensor tTMEM_STORErS_x4_e = recast<Array<Element, kConversionsPerStep>>(tTMEM_STORErS_x4);\n\n    NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;\n\n    const int kReleasePipeCount = 10;  // must be multiple of 2\n\n    order_s.wait();\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(tTMEM_LOADrS); i += 2) {\n      float2 in = make_float2(\n        tTMEM_LOADrS(i + 0),\n        tTMEM_LOADrS(i + 1)\n      );\n      float2 out;\n      cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);\n      tTMEM_LOADrS(i + 0) = out.x;\n      tTMEM_LOADrS(i + 1) = out.y;\n\n      tTMEM_LOADrS(i+0) = 
::exp2f(tTMEM_LOADrS(i+0));\n      tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));\n\n      Array<ElementQK, kConversionsPerStep> in_conv;\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 0; j < kConversionsPerStep; j++) {\n        in_conv[j] = tTMEM_LOADrS(i + j);\n      }\n      tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);\n\n\n      if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {\n        order_s.arrive();\n      }\n\n      // this prevents register spills in fp16\n      if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) {\n        if (i == size(tTMEM_LOADrS) - 6) {\n          copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));\n        }\n      }\n    }\n\n    // tmem_store(reg_S8) -> op_P\n    CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});\n    CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});\n    copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));\n\n    cutlass::arch::fence_view_async_tmem_store();\n\n    // notify tensor core warp that P is ready\n    pipeline_s.consumer_release(pipeline_s_consumer_state);\n    ++pipeline_s_consumer_state;\n\n    pipeline_c.producer_acquire(pipeline_c_producer_state);\n\n    ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));\n    row_sum *= acc_scale;\n    // row_sum = sum(reg_S)\n    float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);\n    float2 local_row_sum_1 = make_float2(0, 0);\n    float2 local_row_sum_2 = make_float2(0, 0);\n    float2 local_row_sum_3 = make_float2(0, 0);\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(tTMEM_LOADrS); i += 8) {\n      // row_sum += tTMEM_LOADrS(i);\n      float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1));\n      cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in);\n\n      in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1));\n      cute::add(local_row_sum_1, local_row_sum_1, 
in);\n\n      in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1));\n      cute::add(local_row_sum_2, local_row_sum_2, in);\n\n      in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1));\n      cute::add(local_row_sum_3, local_row_sum_3, in);\n    }\n\n    cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1);\n    cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);\n    cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);\n    float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;\n\n    row_sum = local_row_sum;\n\n    if (final_call) {\n      // re-acquire the S part in the final step\n      pipeline_s.consumer_wait(pipeline_s_consumer_state);\n\n      Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));\n      tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;\n      tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum;\n      copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);\n    }\n  }\n\n  template<class Stage, class BlkCoord, class ProblemShape>\n  CUTLASS_DEVICE auto\n  softmax(\n      Stage stage,\n      BlkCoord const& blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,\n      PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,\n      OrderBarrierSoftmax& order_s) {\n\n    int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape);\n\n    ElementQK row_max = -INFINITY;\n    ElementQK row_sum = 0;\n\n    Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{}));\n    auto logical_offset = make_coord(\n        get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}),\n        0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})\n    );\n    Tensor cS = domain_offset(logical_offset, cS_base);\n\n    
pipeline_c.producer_acquire(pipeline_c_producer_state);\n\n    CUTLASS_PRAGMA_NO_UNROLL\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n      softmax_step<false /* need_apply_mask */>(\n          row_max, row_sum, stage,\n          (mask_tile_count == 1) &&\n              (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0),\n          blk_coord, cS, params, problem_shape,\n          pipeline_s, pipeline_s_consumer_state,\n          pipeline_c, pipeline_c_producer_state,\n          order_s\n      );\n\n      cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});\n    }\n\n    // Masked iterations\n    mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape);\n\n    CUTLASS_PRAGMA_NO_UNROLL\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n      softmax_step<true /* need_apply_mask */>(\n          row_max, row_sum, stage, mask_tile_count == 1,\n          blk_coord, cS, params, problem_shape,\n          pipeline_s, pipeline_s_consumer_state,\n          pipeline_c, pipeline_c_producer_state,\n          order_s\n      );\n\n      cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});\n    }\n\n    pipeline_c.producer_commit(pipeline_c_producer_state);\n    ++pipeline_c_producer_state;\n\n    pipeline_c.producer_acquire(pipeline_c_producer_state);\n    // empty step to sync against pipe s\n    pipeline_s.consumer_release(pipeline_s_consumer_state);\n    ++pipeline_s_consumer_state;\n  }\n\n  template<class Stage, class TensorO>\n  CUTLASS_DEVICE auto\n  correction_epilogue(\n      float scale,\n      Stage stage,\n      TensorO const& sO_01) {\n\n    using ElementOut = typename TensorO::value_type;\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    Tensor sO = sO_01(_,_,stage);\n\n    // As opposed to the softmax, we do not have enough registers here\n    // to load all of the values (for tile kv = 128), so we loop\n    // 
good values would be either 32 or 64\n    const int kCorrectionTileSize = 32 / sizeof(ElementOut);\n\n    using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>;  // 4x32 threads with 64 cols of 32b elem\n\n    typename CollectiveMmaPV::TiledMma mma;\n    Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));\n    Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));\n    Tensor tOcO = mma.get_slice(0).partition_C(cO);\n    Tensor tOsO = mma.get_slice(0).partition_C(sO);\n\n    Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n    Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n    Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n\n    if constexpr (decltype(stage == _0{})::value) {\n      tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0);\n    }\n    else {\n      static_assert(decltype(stage == _1{})::value, \"stage is either 0 or 1\");\n      tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1);\n    }\n\n    auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));\n    auto thr_tmem_load   = tiled_tmem_load.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));\n    Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));\n    Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));\n\n    float2 scale_f32x2 = make_float2(scale, scale);\n\n    // loop:\n    //   TMEM_LOAD, FMUL2 scale, TMEM_STORE\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {\n      Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i);\n      Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);\n\n      Tensor tTMrO = 
make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));\n\n      copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);\n\n#ifndef ONLY_SOFTMAX\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 0; j < size(tTMrO); j += 2) {\n        float2 in = make_float2(tTMrO(j), tTMrO(j+1));\n        float2 out;\n        cute::mul(out, scale_f32x2, in);\n        tTMrO(j) = out.x;\n        tTMrO(j+1) = out.y;\n      }\n#endif\n\n      constexpr int N = 4 / sizeof(ElementOut);\n      NumericArrayConverter<ElementOut, ElementPV, N> convert;\n\n      Tensor tSMrO = make_tensor_like<ElementOut>(tTMrO);\n\n      Tensor tCs = recast<decltype(convert)::source_type>(tTMrO);\n      Tensor tCd = recast<decltype(convert)::result_type>(tSMrO);\n\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 0; j < size(tCs); j++) {\n        tCd(j) = convert.convert(tCs(j));\n      }\n\n      Tensor tSMsO_i = recast<uint32_t>(tTMEM_LOADsO_i);\n      Tensor tSMrO_i = recast<uint32_t>(tSMrO);\n\n      copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i);\n    }\n\n    cutlass::arch::fence_view_async_shared();\n  }\n\n  CUTLASS_DEVICE auto\n  correction_rescale(\n      float scale,\n      uint32_t tmem_O) {\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    // As opposed to the softmax, we do not have enough registers here\n    // to load all of the values (for tile kv = 128), so we loop\n    // good values would be either 32 or 64\n    const int kCorrectionTileSize = 16;\n\n    using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x;  // 4x32 threads with 64 cols of 32b elem\n    using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x;  // 4x32 threads with 64 cols of 32b elem\n\n    typename CollectiveMmaPV::TiledMma mma;\n    Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));\n    Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));\n    Tensor tOcO = mma.get_slice(0).partition_C(cO);\n\n    Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, 
Int<kCorrectionTileSize>{})));\n    Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n\n    tOtO_i.data() = tOtO_i.data().get() + tmem_O;\n\n    auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);\n    auto thr_tmem_load   = tiled_tmem_load.get_slice(thread_idx);\n    auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);\n    auto thr_tmem_store   = tiled_tmem_store.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);\n    Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);\n    Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);\n    Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i);\n    static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO));\n\n    float2 scale_f32x2 = make_float2(scale, scale);\n\n    Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));\n\n    auto copy_in = [&](int i) {\n      Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;\n      tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);\n      Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));\n      copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i);\n    };\n\n    auto copy_out = [&](int i) {\n      Tensor tTMEM_STOREtO_i = tTMEM_STOREtO;\n      tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize);\n      Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));\n      copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i);\n    };\n\n    // sequence: LLMSLMSLMSS\n\n    // loop:\n    //   TMEM_LOAD, FMUL2 scale, TMEM_STORE\n    copy_in(0);\n\n    int count = get<2>(TileShape{}) / kCorrectionTileSize;\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < count; i++) {\n      if (i != count - 1) {\n        copy_in(i+1);\n      }\n\n      Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 
0; j < size(tTMrO_i); j += 2) {\n        float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));\n        float2 out;\n        cute::mul(out, scale_f32x2, in);\n        tTMrO_i(j) = out.x;\n        tTMrO_i(j+1) = out.y;\n      }\n\n      copy_out(i);\n    }\n  }\n\n  template<\n    class BlkCoord, class ProblemShape, class ParamsProblemShape,\n    class TensorStorageEpi, class CollectiveEpilogue\n  >\n  CUTLASS_DEVICE auto\n  correction(\n      BlkCoord const& blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      ParamsProblemShape const& params_problem_shape,\n      TensorStorageEpi& shared_storage_epi,\n      PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,\n      PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,\n      PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,\n      PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,\n      CollectiveEpilogue& epilogue) {\n\n    int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));\n\n    Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));\n    Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);\n\n    Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));\n    Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));\n\n    using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x;   // 4x32 threads with 2 cols of 32b elem\n\n    auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);\n    auto thr_tmem_loadv  = tiled_tmem_loadv.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v);\n    Tensor tTMEM_LOADVcS = 
thr_tmem_loadv.partition_D(tScS_v);\n\n    Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS;\n    tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0);\n    Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS;\n    tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1);\n\n    // ignore first signal from softmax as no correction is required\n    pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);\n    pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);\n    ++pipeline_s0_c_consumer_state;\n\n    pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);\n\n    // handle the last iteration differently (i.e. tmem_load/stsm for epi)\n    mask_tile_count -= 1;\n\n    CUTLASS_PRAGMA_NO_UNROLL\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n\n      pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);\n\n      Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));\n\n      // read row_wise new global max\n      copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);\n\n      // e^(scale * (old_max - new_max))\n      float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));\n\n      pipeline_o.consumer_wait(pipeline_o_consumer_state);\n\n      correction_rescale(scale, uint32_t(TmemAllocation::O0));\n\n      pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);\n      ++pipeline_s1_c_consumer_state;\n\n      cutlass::arch::fence_view_async_tmem_store();\n\n      pipeline_o.consumer_release(pipeline_o_consumer_state);\n      ++pipeline_o_consumer_state;\n\n      pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);\n\n      copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);\n\n      scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));\n\n      pipeline_o.consumer_wait(pipeline_o_consumer_state);\n\n      correction_rescale(scale, uint32_t(TmemAllocation::O1));\n\n      
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);\n      ++pipeline_s0_c_consumer_state;\n\n      cutlass::arch::fence_view_async_tmem_store();\n\n      pipeline_o.consumer_release(pipeline_o_consumer_state);\n      ++pipeline_o_consumer_state;\n    }\n\n    pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);\n    ++pipeline_s1_c_consumer_state;\n\n    // do the final correction to O1\n    // better to somehow special-case it in the loop above\n    // doesn't matter for non-persistent code, but if it were\n    // persistent we do not want to release O too early\n\n    pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);\n\n    // read from V0\n    // read row_sum and final row_max here\n    Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));\n    copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);\n\n    pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);\n    ++pipeline_s0_c_consumer_state;\n\n    pipeline_o.consumer_wait(pipeline_o_consumer_state);\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n    // store to epi smem\n\n    // loop:\n    //    TMEM_LOAD\n    //    FMUL2 scale = 1 / global_sum * out_quant_scale\n    //    F2FP\n    //    store to smem\n    Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});\n    Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);\n    \n    correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);\n\n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      ElementPV lse = 
cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    cutlass::arch::fence_view_async_tmem_load();\n\n    pipeline_o.consumer_release(pipeline_o_consumer_state);\n    ++pipeline_o_consumer_state;\n\n    pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n\n    pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);\n\n    // load from V1\n    copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);\n\n    pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);\n    ++pipeline_s1_c_consumer_state;\n\n    pipeline_o.consumer_wait(pipeline_o_consumer_state);\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n\n    correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);\n\n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});\n\n      ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    cutlass::arch::fence_view_async_tmem_load();\n\n    pipeline_o.consumer_release(pipeline_o_consumer_state);\n    ++pipeline_o_consumer_state;\n\n    pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n  }\n\n\n  template<\n    class BlkCoord, class ProblemShape, class ParamsProblemShape,\n    class TensorStorageEpi, class CollectiveEpilogue\n  >\n  
CUTLASS_DEVICE auto\n  correction_empty(\n      BlkCoord const& blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      ParamsProblemShape const& params_problem_shape,\n      TensorStorageEpi& shared_storage_epi,\n      PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,\n      CollectiveEpilogue& epilogue) {\n\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n\n    Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});\n    Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);\n    float lse = -INFINITY;\n    int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);\n\n#if 1\n\n    using ElementOut = typename CollectiveEpilogue::ElementOut;\n    auto tiled_copy = make_cotiled_copy(\n        Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},\n        make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),\n        sO.layout());\n\n    auto thr_copy = tiled_copy.get_slice(thread_idx);\n    auto tOgO = thr_copy.partition_D(sO);\n    auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));\n    clear(tOrO);\n    \n    copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));\n#endif\n    \n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n\n    copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));\n    
cutlass::arch::fence_view_async_shared();\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n\n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    cutlass::arch::fence_view_async_shared();\n    pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n  }\n\n};\n\n}  // namespace cutlass::fmha::collective\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/arch/memory_sm80.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n#include \"cute/tensor.hpp\"\n#include \"cute/layout.hpp\"\n\n#include \"../collective/fmha_common.hpp\"\n#include \"../collective/fmha_fusion.hpp\"\n\nnamespace cutlass::fmha::collective {\n\nusing namespace cute;\n\ntemplate<\n  class Element,\n  class StrideQ,\n  class StrideK,\n  class StrideV,\n  class CollectiveMmaQK,\n  class CollectiveMmaPV,\n  class SmemLayoutQ,\n  class SmemLayoutK,\n  class SmemLayoutV,\n  class TensorStorage,\n  class PipelineQ,\n  class PipelineKV,\n  class Mask,\n  class TileShape\n>\nstruct Sm100FmhaLoadTmaWarpspecialized {\n\n  using TileShapeQK = typename CollectiveMmaQK::TileShape;\n  using TileShapePV = typename CollectiveMmaPV::TileShape;\n\n  struct Arguments {\n    const Element* ptr_Q;\n    StrideQ dQ;\n    const Element* ptr_K;\n    StrideK dK;\n    const Element* ptr_V;\n    StrideV dV;\n  };\n\n  using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;\n  using TMA_K = typename CollectiveMmaQK::Params::TMA_B;\n  using TMA_V = typename CollectiveMmaPV::Params::TMA_B;\n\n  struct Params {\n    TMA_Q tma_load_q;\n    TMA_K tma_load_k;\n    TMA_V tma_load_v;\n  };\n\n  template<class ProblemShape>\n  static Params to_underlying_arguments(\n      
ProblemShape const& problem_shape,\n      Arguments const& args,\n      void* workspace) {\n\n    auto ptr_Q = args.ptr_Q;\n    auto ptr_K = args.ptr_K;\n    auto ptr_V = args.ptr_V;\n    auto dQ = args.dQ;\n    auto dK = args.dK;\n    auto dV = args.dV;\n\n    using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;\n\n    IntProblemShape problem_shape_qk;\n    if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {\n      auto cumulative_length_q = get<0>(problem_shape).cumulative_length;\n      auto cumulative_length_k = get<1>(problem_shape).cumulative_length;\n      if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {\n          get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;\n          get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;\n          get<2>(problem_shape_qk) = get<2>(problem_shape);\n          get<3>(problem_shape_qk) = get<3>(problem_shape);\n      }\n    } else {\n      problem_shape_qk = problem_shape;\n    }\n\n    get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));\n    get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));\n\n    auto params_qk = CollectiveMmaQK::to_underlying_arguments(\n        problem_shape_qk,\n        typename CollectiveMmaQK::Arguments {\n            ptr_Q, dQ,\n            ptr_K, dK,\n        }, /*workspace=*/ nullptr);\n\n    auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);\n    auto params_pv = CollectiveMmaPV::to_underlying_arguments(\n        problem_shape_pv,\n        typename CollectiveMmaPV::Arguments {\n            ptr_K, dK,  // never used, dummy\n            ptr_V, select<1,0,2>(dV),\n        }, /*workspace=*/ nullptr);\n\n    return Params{\n        params_qk.tma_load_a,\n        params_qk.tma_load_b,\n        params_pv.tma_load_b\n    };\n  }\n\n\n  CUTLASS_DEVICE\n  static void prefetch_tma_descriptors(Params const& params) {\n    
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());\n    cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());\n    cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());\n  }\n\n  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>\n  CUTLASS_DEVICE void\n  load(\n      BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,\n      Params const& params, ParamsProblemShape const& params_problem_shape,\n      TensorStorage& storage,\n      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,\n      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {\n\n    BlkCoord blk_coord_q = blk_coord_in;\n    BlkCoord blk_coord_kv = blk_coord_in;\n\n    int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);\n\n    using X = Underscore;\n\n    // this one is only executed by one thread, no need to elect_one\n\n    // Q1, K1, Q2, V1, K2, V2, K3, V3, ...\n    // two pipes: Q and KV\n    // from Memory (prod) to TensorCore (cons)\n\n    // compute gQ, sQ\n    // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1\n    ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);\n    Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));\n\n    int q_offs_0 = 0;\n\n    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n      auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;\n      if (cumulative_length_q != nullptr) {\n        q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];\n        get<2,1>(blk_coord_q) = 0;\n      }\n    }\n\n    Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);\n\n    Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});\n    Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);\n    Tensor sQ = 
make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});\n    auto [tQgQ_qdl, tQsQ] = tma_partition(\n      params.tma_load_q, _0{}, make_layout(_1{}), \n      group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)\n    );\n    Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));\n\n    // compute gK, sK\n    Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));\n\n    int kv_offs_0 = 0;\n\n    if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {\n      auto cumulative_length = get<1>(params_problem_shape).cumulative_length;\n      if (cumulative_length != nullptr) {\n        kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];\n        get<2,1>(blk_coord_kv) = 0;\n      }\n    }\n\n    Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);\n\n    Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});\n    Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);\n    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});\n    auto [tKgK_kdl, tKsK] = tma_partition(\n      params.tma_load_k, _0{}, make_layout(_1{}),\n      group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)\n    );\n    Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));\n\n    // compute gV, sV\n    ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);\n    Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));\n\n    Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);\n\n    Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});\n    Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);\n    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});\n    auto [tVgV_dkl, tVsV] = tma_partition(\n      params.tma_load_v, _0{}, make_layout(_1{}),\n      group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)\n    );\n    auto tVgV = 
tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));\n\n    // blk_coord is decomposed in terms of TileShape, not TileShapeQK\n    // As such, it needs to be transformed as\n    // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)\n    //          b -> 2*a (Ki i even) 2*a+1 (Ki i odd)\n\n    uint32_t lane_predicate = cute::elect_one_sync();\n\n    // Q1\n    int q0_index = 2 * get<0>(blk_coord_q);\n    int q1_index = 2 * get<0>(blk_coord_q) + 1;\n    pipeline_q.producer_acquire(pipeline_q_producer_state);\n    if (lane_predicate) {\n      auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);\n      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));\n    }\n    ++pipeline_q_producer_state;\n\n    // K1\n    int k_index = 0;\n    pipeline_kv.producer_acquire(pipeline_kv_producer_state);\n    if (lane_predicate) {\n      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n      copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));\n    }\n    ++pipeline_kv_producer_state;\n\n    // Q2\n    pipeline_q.producer_acquire(pipeline_q_producer_state);\n    if (lane_predicate) {\n      auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);\n      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));\n    }\n    ++pipeline_q_producer_state;\n\n    // V1\n    pipeline_kv.producer_acquire(pipeline_kv_producer_state);\n    if (lane_predicate) {\n      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n      copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));\n    }\n    ++pipeline_kv_producer_state;\n    k_index += 1;\n\n    // loop:\n    mask_tile_count -= 1;\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n\n      // Ki\n      
pipeline_kv.producer_acquire(pipeline_kv_producer_state);\n      if (lane_predicate) {\n        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n        copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));\n      }\n      ++pipeline_kv_producer_state;\n\n      // Vi\n      pipeline_kv.producer_acquire(pipeline_kv_producer_state);\n      if (lane_predicate) {\n        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n        copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));\n      }\n      ++pipeline_kv_producer_state;\n      k_index += 1;\n    }\n  }\n};\n\n}  // namespace cutlass::fmha::collective\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/arch/memory_sm80.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n#include \"cute/arch/simd_sm100.hpp\"\n#include \"cute/tensor.hpp\"\n#include \"cute/layout.hpp\"\n\n#include \"../collective/fmha_common.hpp\"\n#include \"../collective/fmha_fusion.hpp\"\n#include \"../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp\"\n#include \"../common/pipeline_mla.hpp\"\n\nnamespace cutlass::fmha::collective {\n\nusing namespace cute;\n\ntemplate<\n  class Element_,\n  class ElementQK_,\n  class ElementPV_,\n  class ComposedTileShape_,\n  class StrideQ_,\n  class StrideK_,\n  class StrideV_,\n  class Mask_,\n  // shape here is QG K H\n  // and refers to the two softmax warps\n  // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)\n  // (1, 2, 1) means they sit side by side (best for small Q / large K)\n  class ThreadShape = Shape<_2, _1, _1>,\n  class OrderLoadEpilogue = cute::false_type\n>\nstruct Sm100MlaFwdMainloopTmaWarpspecialized {\n\n  using Element = Element_;\n  using ElementQK = ElementQK_;\n  using ElementPV = ElementPV_;\n  using ComposedTileShape = ComposedTileShape_;\n  using StrideQ = StrideQ_;\n  using StrideK = StrideK_;\n  using StrideV = StrideV_;\n  using Mask = Mask_;\n\n  static constexpr int 
StageCountQ = 2;\n  static constexpr int StageCountK = 1;\n  static constexpr int StageCountV = 1;\n  static constexpr int StageCountKV = StageCountK + StageCountV;\n  // Support StageCountKV > 2 in the future. \n  static_assert(StageCountK == 1 && StageCountV == 1, \"Only support StageCountK = StageCountV = 1!\");\n  static_assert(std::is_same_v<ThreadShape, Shape<_2, _1, _1>>, \"Only support ThreadShape = Shape<_2, _1, _1>\");\n\n  using ClusterShape = Shape<_1, _1, _1>;\n\n  static const int Alignment = 128 / sizeof_bits_v<Element>;\n\n  static constexpr auto  HeadDimLatent = size<2, 0>(ComposedTileShape{});\n  static constexpr auto  HeadDimRope = size<2, 1>(ComposedTileShape{});\n  static constexpr auto  HeadDimQK = HeadDimLatent + HeadDimRope;\n  static constexpr auto  HeadDimPV = HeadDimLatent;\n\n  using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{}));\n  using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{})));\n  using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent));\n\n  using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      Element, StrideQ, Alignment,\n      Element, StrideK, Alignment,\n      ElementQK,\n      TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,\n      cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;\n\n  using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // the stride for A does not matter since we do not load from smem at all\n      Element, StrideK, Alignment,\n      Element, decltype(select<1,0,2>(StrideV{})), Alignment,\n      ElementPV,\n      TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,\n      
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;\n\n  using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StageCountQ>{}));\n  using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountK>{}));\n  using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountV>{}));\n\n  using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{})));\n  \n  // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, \n  // we reuse shared memory for V and O to address this problem, \n  // and a barrier has been added to coordinate access to shared memory.\n  static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;\n  static const int NumWarpsEpilogue = 1;\n  static const int NumWarpsLoad = 1;\n  \n  struct TensorStorageQKVO {\n    cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n    cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k; \n    cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_o; // use as O0\n    cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // use as V0 and O1\n  };\n\n  struct TensorStorageQKV {\n    cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n    cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k; \n    cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n  };\n\n  using TensorStorage = std::conditional_t<IsOrderLoadEpilogue, TensorStorageQKVO, TensorStorageQKV>;\n\n  enum class TmemAllocation : uint32_t {\n    kSizeS = 128,\n    kSizeO = 128,\n    kSizeP = 32,\n    S0 = 0,\n    S1 = S0 + kSizeS,\n    V0 = S0,  // stats storage from softmax to correction\n    V1 = S1,\n    P0 = S0 + kSizeP,\n    P1 = S1 + kSizeP,\n    O0 = S1 + kSizeS,\n    O1 = O0 + kSizeO,\n    kEnd = O1 + kSizeO\n  };\n\n  // indices for V0 / V1\n  enum : 
int {\n    kIdxOldRowMax = 0,\n    kIdxNewRowMax = 1,\n    kIdxFinalRowSum = 0,\n    kIdxFinalRowMax = 1\n  };\n\n  // from load to mma warp, protects q in smem\n  using PipelineQ = cutlass::PipelineTmaUmmaAsync<\n    StageCountQ,\n    typename CollectiveMmaQK::AtomThrShapeMNK\n  >;\n\n  // from load to mma warp, protects k/v in smem\n  using PipelineKV = cutlass::PipelineTmaAsyncMla<\n    StageCountKV,\n    typename CollectiveMmaQK::AtomThrShapeMNK\n  >;\n\n  // from mma to softmax0/1 warp, protects S in tmem\n  // (not sure yet about the reverse direction)\n  // there is one pipe per softmax warp, and the mma warp alternates between them\n  using PipelineS = cutlass::PipelineUmmaAsync<1>;\n\n  // from softmax0/1/ to correction wg\n  using PipelineC = cutlass::PipelineAsync<1>;\n\n  // from mma to correction\n  using PipelineO = cutlass::PipelineUmmaAsync<2>;\n\n  // from corr to epilogue\n  using PipelineE = cutlass::PipelineAsync<2>;\n\n  using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier<\n    /*stages*/ 1, /*groups*/ 2>;\n\n  static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);\n  static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);\n  static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);\n\n  using Load = Sm100MlaFwdLoadTmaWarpspecialized<\n    Element, StrideQ, StrideK, StrideV,\n    CollectiveMmaQK, CollectiveMmaPV,\n    SmemLayoutQ, SmemLayoutK, SmemLayoutV,\n    TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue\n  >;\n\n  struct Arguments {\n    typename Load::Arguments load;\n\n    // if zero, defaults to 1/sqrt(D)\n    float scale_softmax = 0.0f;\n\n    // scaling factors to dequantize QKV\n    float scale_q = 1.0f;\n    float scale_k = 1.0f;\n    float scale_v = 1.0f;\n\n    // 
scaling factor to quantize O\n    float inv_scale_o = 1.0f;\n  };\n\n  struct Params {\n    typename Load::Params load;\n\n    float scale_softmax;\n    float scale_softmax_log2;\n\n    float scale_output;\n  };\n\n  template<class ProblemShape>\n  static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) {\n    return true;\n  }\n\n  template<class ProblemShape>\n  static Params to_underlying_arguments(\n      ProblemShape const& problem_shape,\n      Arguments const& args,\n      void* workspace) {\n\n    float scale_softmax = args.scale_softmax;\n    if (scale_softmax == 0.0f) {\n      scale_softmax = 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape));\n    }\n    float log2_e = static_cast<float>(std::log2(std::exp(1.0)));\n\n    return Params{\n        Load::to_underlying_arguments(problem_shape, args.load, workspace),\n        args.scale_q * args.scale_k * scale_softmax,\n        args.scale_q * args.scale_k * log2_e * scale_softmax,\n        args.scale_v * args.inv_scale_o\n    };\n  }\n\n  CUTLASS_DEVICE\n  static void prefetch_tma_descriptors(Params const& params) {\n      Load::prefetch_tma_descriptors(params.load);\n  }\n\n  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>\n  CUTLASS_DEVICE void\n  load(\n      BlkCoord const& blk_coord, ProblemShape const& problem_shape,\n      Params const& params, ParamsProblemShape const& params_problem_shape,\n      TensorStorage& storage,\n      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,\n      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {\n\n    Load load;\n    load.load(blk_coord, problem_shape, params.load, params_problem_shape,\n        storage,\n        pipeline_q, pipeline_q_producer_state,\n        pipeline_kv, pipeline_kv_producer_state);\n  }\n\n  template<class BlkCoord, class ProblemShape>\n  CUTLASS_DEVICE auto\n  mma(\n      BlkCoord const& 
blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      TensorStorage& storage,\n      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state,\n      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state,\n      PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state,\n      PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state,\n      PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) {\n\n    auto pipeline_q_release_state = pipeline_q_consumer_state;\n    auto pipeline_kv_release_state = pipeline_kv_consumer_state;\n\n    int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);\n\n    typename CollectiveMmaQK::TiledMma mma_qk;\n    ThrMMA thr_mma_qk = mma_qk.get_slice(0);\n\n    typename CollectiveMmaPV::TiledMma mma_pv;\n    TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv);\n    ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0);\n\n    Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});\n    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});\n    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});\n\n    Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ);\n    Tensor tSrK = thr_mma_qk.make_fragment_B(sK);\n    Tensor tOrV = thr_mma_pv.make_fragment_B(sV);\n\n    // tmem layout is\n    // S0 S1 O0 O1\n    // sequential in memory, where S overlaps with P and V\n\n    Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{}));\n    Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{}));\n\n    Tensor tStS0 = tStS;\n    tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0);\n    Tensor tStS1 = tStS;\n    tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1);\n\n    Tensor tOtO0 = tOtO;\n    tOtO0.data() = tOtO.data().get() + 
uint32_t(TmemAllocation::O0);\n    Tensor tOtO1 = tOtO;\n    tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1);\n\n    Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{});\n    Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{});  // slice out staging\n\n    Tensor tOrP0 = tOrP;\n    tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);\n    Tensor tOrP1 = tOrP;\n    tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1);\n\n    int k_index = 0;\n    int v_index = 0;\n    int q_index = 0;\n\n    // wait for Q1\n    q_index = pipeline_q_consumer_state.index();\n    pipeline_q.consumer_wait(pipeline_q_consumer_state);\n    ++pipeline_q_consumer_state;\n\n    Tensor tSrQ0 = tSrQ(_,_,_,q_index);\n\n\n    // wait for K1\n    k_index = pipeline_kv_consumer_state.index();\n    pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n    ++pipeline_kv_consumer_state;\n\n    // gemm Q1 * K1 -> S1\n    pipeline_s0.producer_acquire(pipeline_s0_producer_state);\n\n    gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0);\n\n    pipeline_s0.producer_commit(pipeline_s0_producer_state);\n    ++pipeline_s0_producer_state;\n\n    // release K1\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n    }\n\n    // wait for Q2\n    if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) {\n      q_index = pipeline_q_consumer_state.index();\n      pipeline_q.consumer_wait(pipeline_q_consumer_state);\n      ++pipeline_q_consumer_state;\n    }\n\n    Tensor tSrQ1 = tSrQ(_,_,_,q_index);\n\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      k_index = pipeline_kv_consumer_state.index();\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n    }\n\n    pipeline_s1.producer_acquire(pipeline_s1_producer_state);\n\n    // gemm Q2 * K1 -> S2\n 
   gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1);\n\n    pipeline_s1.producer_commit(pipeline_s1_producer_state);\n    ++pipeline_s1_producer_state;\n\n    // release K1\n    pipeline_kv.consumer_release(pipeline_kv_release_state);\n    ++pipeline_kv_release_state;\n\n    // wait for V1\n    v_index = pipeline_kv_consumer_state.index();\n    pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n    ++pipeline_kv_consumer_state;\n\n    // this acquire returns the ownership of all of S0 to the mma warp\n    // including the P0 part\n    // acquire corr first to take it out of the critical\n    // path since softmax takes longer\n    pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n    pipeline_s0.producer_acquire(pipeline_s0_producer_state);\n\n    // gemm P1 * V1 -> O1\n    gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0);\n\n    pipeline_corr.producer_commit(pipeline_corr_producer_state);\n    ++pipeline_corr_producer_state;\n\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n    }\n\n    mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero;\n\n    // loop:\n    mask_tile_count -= 1;\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n\n      // wait for Ki\n      k_index = (pipeline_kv_consumer_state.index());\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n\n      // gemm Q1 * Ki -> S1\n      gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0);\n\n      pipeline_s0.producer_commit(pipeline_s0_producer_state);\n      ++pipeline_s0_producer_state;\n\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        pipeline_kv.consumer_release(pipeline_kv_release_state);\n        ++pipeline_kv_release_state;\n      }\n\n      // gemm P2 * V(i-1) -> O2\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        v_index = pipeline_kv_consumer_state.index();\n        
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n        ++pipeline_kv_consumer_state;\n      }\n\n      pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n      pipeline_s1.producer_acquire(pipeline_s1_producer_state);\n\n      gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1);\n\n      pipeline_corr.producer_commit(pipeline_corr_producer_state);\n      ++pipeline_corr_producer_state;\n\n      // release V(i-1)\n      pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        k_index = (pipeline_kv_consumer_state.index());\n        pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n        ++pipeline_kv_consumer_state;\n      }\n\n      // gemm Q2 * Ki -> S2\n      gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1);\n\n      pipeline_s1.producer_commit(pipeline_s1_producer_state);\n      ++pipeline_s1_producer_state;\n\n      // release Ki\n      pipeline_kv.consumer_release(pipeline_kv_release_state);\n      ++pipeline_kv_release_state;\n\n      // wait for Vi\n      v_index = (pipeline_kv_consumer_state.index());\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n\n      // gemm P1 * Vi -> O1\n      pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n\n      pipeline_s0.producer_acquire(pipeline_s0_producer_state);\n\n      gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0);\n\n      pipeline_corr.producer_commit(pipeline_corr_producer_state);\n      ++pipeline_corr_producer_state;\n\n      if constexpr (get<1>(ThreadShape{}) > 1) {\n        pipeline_kv.consumer_release(pipeline_kv_release_state);\n        ++pipeline_kv_release_state;\n      }\n    }\n\n    // release Q1\n    pipeline_q.consumer_release(pipeline_q_release_state);\n    ++pipeline_q_release_state;\n\n    // release Q2\n    if constexpr (get<0>(ThreadShape{}) > 1) {\n      
pipeline_q.consumer_release(pipeline_q_release_state);\n      ++pipeline_q_release_state;\n    }\n\n    // wait for Vi\n    if constexpr (get<1>(ThreadShape{}) > 1) {\n      v_index = pipeline_kv_consumer_state.index();\n      pipeline_kv.consumer_wait(pipeline_kv_consumer_state);\n      ++pipeline_kv_consumer_state;\n    }\n\n    // gemm P2 * Vi -> O2\n    pipeline_corr.producer_acquire(pipeline_corr_producer_state);\n    pipeline_s1.producer_acquire(pipeline_s1_producer_state);\n\n    gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1);\n\n    pipeline_corr.producer_commit(pipeline_corr_producer_state);\n    ++pipeline_corr_producer_state;\n\n    // release Vi\n    pipeline_kv.consumer_release(pipeline_kv_release_state);\n    ++pipeline_kv_release_state;\n\n    pipeline_s0.producer_commit(pipeline_s0_producer_state);\n    ++pipeline_s0_producer_state;\n\n    pipeline_s1.producer_commit(pipeline_s1_producer_state);\n    ++pipeline_s1_producer_state;\n\n    // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ...\n    // Q1 * K1  , Q2 * K1  , S11 * V1 , Q1 * K2  , S21 * V1  , Q2 * K2 , S12 * V2 , Q1 * K3  , S22 * V2 , ...\n  }\n\n  template<bool need_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>\n  CUTLASS_DEVICE auto\n  softmax_step(\n      bool need_apply_mask,\n      float& row_max, float& row_sum,\n      Stage stage, bool final_call,\n      BlkCoord const& blk_coord, CoordTensor const& cS,\n      Params const& params, ProblemShape const& problem_shape,\n      PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,\n      PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,\n      OrderBarrierSoftmax& order_s) {\n\n    Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);\n\n    Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));\n   
 tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);\n\n    Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));\n    tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);\n    Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));\n\n    auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};\n    Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));\n    tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));\n    Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));\n\n    // Each thread owns a single row\n    using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem\n    using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x;  // 4x32 threads with 128 cols of 8b elem\n    using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x;   // 4x32 threads with 2 cols of 32b elem\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);\n    auto thr_tmem_load   = tiled_tmem_load.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS);\n    Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS);\n\n    auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v);\n    auto thr_tmem_storev  = tiled_tmem_storev.get_slice(thread_idx);\n\n    Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v);\n    Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v);\n\n    auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P);\n    auto thr_tmem_store  = tiled_tmem_store.get_slice(thread_idx);\n\n    Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P);\n    tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get());\n    Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P);\n\n    // wait on tensor core pipe\n    
pipeline_s.consumer_wait(pipeline_s_consumer_state);\n\n    // read all of S from tmem into reg mem\n    Tensor tTMEM_LOADrS = make_tensor<ElementQK>(shape(tTMEM_LOADcS));\n    copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS);\n\n    if constexpr (need_mask) {\n      if(need_apply_mask) {\n        Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape);\n      }\n    }\n\n    ElementQK old_row_max = row_max;\n    {\n      // compute rowmax\n      float row_max_0 = row_max;\n      float row_max_1 = row_max;\n      float row_max_2 = row_max;\n      float row_max_3 = row_max;\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size(tTMEM_LOADrS); i += 4) {\n        row_max_0  = ::fmax(row_max_0, tTMEM_LOADrS(i));\n        row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1));\n        row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2));\n        row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3));\n      }\n      row_max = ::fmax(row_max_0, row_max_1);\n      row_max = ::fmax(row_max, row_max_2);\n      row_max = ::fmax(row_max, row_max_3);\n    }\n\n    ElementQK row_max_safe = row_max == -INFINITY ? 
0 : row_max;\n\n    Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));\n    tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max;\n    tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe;\n    copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);\n\n    pipeline_c.producer_commit(pipeline_c_producer_state);\n    ++pipeline_c_producer_state;\n\n    // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's)\n\n    ElementQK scale = params.scale_softmax_log2;\n    ElementQK row_max_scale = row_max_safe * scale;\n\n    float2 scale_fp32x2 = make_float2(scale, scale);\n    float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale);\n\n    Tensor tTMEM_STORErS_x4 = make_tensor<uint32_t>(shape(tTMEM_STOREcS));\n\n    constexpr int kConversionsPerStep = 2;\n\n    Tensor tTMEM_STORErS_x4_e = recast<Array<Element, kConversionsPerStep>>(tTMEM_STORErS_x4);\n\n    NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;\n\n    constexpr int kReleasePipeCount = 10;  // must be multiple of 2\n\n    order_s.wait();\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(tTMEM_LOADrS); i += 2) {\n      float2 in = make_float2(\n        tTMEM_LOADrS(i + 0),\n        tTMEM_LOADrS(i + 1)\n      );\n      float2 out;\n      cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);\n      tTMEM_LOADrS(i + 0) = out.x;\n      tTMEM_LOADrS(i + 1) = out.y;\n\n      tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0));\n      tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));\n\n      Array<ElementQK, kConversionsPerStep> in_conv;\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 0; j < kConversionsPerStep; j++) {\n        in_conv[j] = tTMEM_LOADrS(i + j);\n      }\n      tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);\n\n\n      if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {\n        order_s.arrive();\n      }\n\n      // this prevents register spills in fp16\n      if constexpr 
(size<2>(tTMEM_STORErS_x4) == _2{}) {\n        if (i == size(tTMEM_LOADrS) - 6) {\n          copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));\n        }\n      }\n    }\n\n    // tmem_store(reg_S8) -> op_P\n    CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});\n    CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});\n    copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));\n\n    cutlass::arch::fence_view_async_tmem_store();\n\n    // notify tensor core warp that P is ready\n    pipeline_s.consumer_release(pipeline_s_consumer_state);\n    ++pipeline_s_consumer_state;\n\n    pipeline_c.producer_acquire(pipeline_c_producer_state);\n\n    ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));\n    row_sum *= acc_scale;\n    // row_sum = sum(reg_S)\n    float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);\n    float2 local_row_sum_1 = make_float2(0, 0);\n    float2 local_row_sum_2 = make_float2(0, 0);\n    float2 local_row_sum_3 = make_float2(0, 0);\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(tTMEM_LOADrS); i += 8) {\n      // row_sum += tTMEM_LOADrS(i);\n      float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1));\n      cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in);\n\n      in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1));\n      cute::add(local_row_sum_1, local_row_sum_1, in);\n\n      in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1));\n      cute::add(local_row_sum_2, local_row_sum_2, in);\n\n      in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1));\n      cute::add(local_row_sum_3, local_row_sum_3, in);\n    }\n\n    cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1);\n    cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);\n    cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);\n    float local_row_sum = 
local_row_sum_f32x2.x + local_row_sum_f32x2.y;\n\n    row_sum = local_row_sum;\n\n    if (final_call) {\n      // re-acquire the S part in the final step\n      pipeline_s.consumer_wait(pipeline_s_consumer_state);\n\n      Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));\n      tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;\n      tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum;\n      copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);\n    }\n  }\n\n  template<class Stage, class BlkCoord, class ProblemShape>\n  CUTLASS_DEVICE auto\n  softmax(\n      Stage stage,\n      BlkCoord const& blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,\n      PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,\n      OrderBarrierSoftmax& order_s) {\n    const int mask_trip_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape);\n    const int total_trip_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);\n    int trip_idx = total_trip_count;\n\n    ElementQK row_max = -INFINITY;\n    ElementQK row_sum = 0;\n\n    Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{}));\n    auto logical_offset = make_coord(\n        get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}),\n        0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})\n    );\n    Tensor cS = domain_offset(logical_offset, cS_base);\n\n    pipeline_c.producer_acquire(pipeline_c_producer_state);\n    \n    constexpr bool NeedMask = !std::is_same_v<Mask, NoMask>;\n\n    CUTLASS_PRAGMA_NO_UNROLL\n    for (; trip_idx > 0; trip_idx -= 1) {\n      softmax_step<NeedMask /* need_mask */>(\n          trip_idx <= mask_trip_count,\n          row_max, row_sum, stage,\n          trip_idx == 1,\n          blk_coord, cS, params, problem_shape,\n          pipeline_s, 
pipeline_s_consumer_state,\n          pipeline_c, pipeline_c_producer_state,\n          order_s\n      );\n\n      cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});\n    }\n\n    pipeline_c.producer_commit(pipeline_c_producer_state);\n    ++pipeline_c_producer_state;\n\n    pipeline_c.producer_acquire(pipeline_c_producer_state);\n    // empty step to sync against pipe s\n    pipeline_s.consumer_release(pipeline_s_consumer_state);\n    ++pipeline_s_consumer_state;\n  }\n\n  template<class Stage, class TensorO>\n  CUTLASS_DEVICE auto\n  correction_epilogue(\n      float scale,\n      Stage stage,\n      TensorO const& sO_01) {\n\n    using ElementOut = typename TensorO::value_type;\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    Tensor sO = sO_01(_,_,stage);\n\n    // As opposed to the softmax, we do not have enough registers here\n    // to load all of the values (for tile kv = 128), so we loop\n    // good values would be either 32 or 64\n    constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut);\n\n    using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>;  // 4x32 threads with 64 cols of 32b elem\n\n    typename CollectiveMmaPV::TiledMma mma;\n    Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));\n    Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));\n    Tensor tOcO = mma.get_slice(0).partition_C(cO);\n    Tensor tOsO = mma.get_slice(0).partition_C(sO);\n\n    Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n    Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n    Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n\n    if constexpr (decltype(stage == _0{})::value) {\n      tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0);\n    }\n    
else {\n      static_assert(decltype(stage == _1{})::value, \"stage is either 0 or 1\");\n      tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1);\n    }\n\n    auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));\n    auto thr_tmem_load   = tiled_tmem_load.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));\n    Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));\n    Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));\n\n    float2 scale_f32x2 = make_float2(scale, scale);\n\n    // loop:\n    //   TMEM_LOAD, FMUL2 scale, TMEM_STORE\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {\n      Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i);\n      Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);\n\n      Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));\n\n      copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);\n\n#ifndef ONLY_SOFTMAX\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 0; j < size(tTMrO); j += 2) {\n        float2 in = make_float2(tTMrO(j), tTMrO(j+1));\n        float2 out;\n        cute::mul(out, scale_f32x2, in);\n        tTMrO(j) = out.x;\n        tTMrO(j+1) = out.y;\n      }\n#endif\n\n      constexpr int N = 4 / sizeof(ElementOut);\n      NumericArrayConverter<ElementOut, ElementPV, N> convert;\n\n      Tensor tSMrO = make_tensor_like<ElementOut>(tTMrO);\n\n      Tensor tCs = recast<decltype(convert)::source_type>(tTMrO);\n      Tensor tCd = recast<decltype(convert)::result_type>(tSMrO);\n\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 0; j < size(tCs); j++) {\n        tCd(j) = convert.convert(tCs(j));\n      }\n\n      Tensor tSMsO_i = recast<uint32_t>(tTMEM_LOADsO_i);\n      Tensor tSMrO_i = recast<uint32_t>(tSMrO);\n\n      copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i);\n    
}\n\n    cutlass::arch::fence_view_async_shared();\n  }\n\n  CUTLASS_DEVICE auto\n  correction_rescale(\n      float scale,\n      uint32_t tmem_O) {\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    // As opposed to the softmax, we do not have enough registers here\n    // to load all of the values (for tile kv = 128), so we loop\n    // good values would be either 32 or 64\n    constexpr int kCorrectionTileSize = 16;\n\n    using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x;  // 4x32 threads with 64 cols of 32b elem\n    using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x;  // 4x32 threads with 64 cols of 32b elem\n\n    typename CollectiveMmaPV::TiledMma mma;\n    Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));\n    Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));\n    Tensor tOcO = mma.get_slice(0).partition_C(cO);\n\n    Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n    Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));\n\n    tOtO_i.data() = tOtO_i.data().get() + tmem_O;\n\n    auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);\n    auto thr_tmem_load   = tiled_tmem_load.get_slice(thread_idx);\n    auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);\n    auto thr_tmem_store   = tiled_tmem_store.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);\n    Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);\n    Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);\n    Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i);\n    static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO));\n\n    float2 scale_f32x2 = make_float2(scale, scale);\n\n    Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));\n\n    auto copy_in = [&](int i) {\n      Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;\n      tTMEM_LOADtO_i.data() = 
tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);\n      Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));\n      copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i);\n    };\n\n    auto copy_out = [&](int i) {\n      Tensor tTMEM_STOREtO_i = tTMEM_STOREtO;\n      tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize);\n      Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));\n      copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i);\n    };\n\n    // sequence: LLMSLMSLMSS\n\n    // loop:\n    //   TMEM_LOAD, FMUL2 scale, TMEM_STORE\n    copy_in(0);\n\n    constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize;\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < count; i++) {\n      if (i != count - 1) {\n        copy_in(i+1);\n      }\n\n      Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));\n      CUTLASS_PRAGMA_UNROLL\n      for (int j = 0; j < size(tTMrO_i); j += 2) {\n        float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));\n        float2 out;\n        cute::mul(out, scale_f32x2, in);\n        tTMrO_i(j) = out.x;\n        tTMrO_i(j+1) = out.y;\n      }\n\n      copy_out(i);\n    }\n  }\n\n  template<\n    class BlkCoord, class ProblemShape, class ParamsProblemShape,\n    class TensorStorageEpi, class CollectiveEpilogue\n  >\n  CUTLASS_DEVICE auto\n  correction(\n      BlkCoord const& blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      ParamsProblemShape const& params_problem_shape,\n      TensorStorageEpi& shared_storage_epi,\n      PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,\n      PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,\n      PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,\n      PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,\n      
CollectiveEpilogue& epilogue) {\n\n    int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);\n\n    int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);\n\n    Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));\n\n    Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));\n    Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);\n\n    Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));\n    Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));\n\n    using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x;   // 4x32 threads with 2 cols of 32b elem\n\n    auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);\n    auto thr_tmem_loadv  = tiled_tmem_loadv.get_slice(thread_idx);\n\n    Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v);\n    Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v);\n\n    Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS;\n    tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0);\n    Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS;\n    tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1);\n\n    // ignore first signal from softmax as no correction is required\n    pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);\n    pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);\n    ++pipeline_s0_c_consumer_state;\n\n    pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);\n\n    // handle the last iteration differently (i.e. 
tmem_load/stsm for epi)\n    mask_tile_count -= 1;\n\n    CUTLASS_PRAGMA_NO_UNROLL\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n\n      pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);\n\n      Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));\n\n      // read row-wise new global max\n      copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);\n\n      // e^(scale * (old_max - new_max))\n      float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));\n\n      pipeline_o.consumer_wait(pipeline_o_consumer_state);\n\n      correction_rescale(scale, uint32_t(TmemAllocation::O0));\n\n      pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);\n      ++pipeline_s1_c_consumer_state;\n\n      cutlass::arch::fence_view_async_tmem_store();\n\n      pipeline_o.consumer_release(pipeline_o_consumer_state);\n      ++pipeline_o_consumer_state;\n\n      pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);\n\n      copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);\n\n      scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));\n\n      pipeline_o.consumer_wait(pipeline_o_consumer_state);\n\n      correction_rescale(scale, uint32_t(TmemAllocation::O1));\n\n      pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);\n      ++pipeline_s0_c_consumer_state;\n\n      cutlass::arch::fence_view_async_tmem_store();\n\n      pipeline_o.consumer_release(pipeline_o_consumer_state);\n      ++pipeline_o_consumer_state;\n    }\n\n    pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);\n    ++pipeline_s1_c_consumer_state;\n\n    // do the final correction to O1\n    // better to somehow special-case it in the loop above\n    // doesn't matter for non-persistent code, but if it were\n    // persistent we do not want to release O too early\n\n    pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);\n\n    // 
read from V0\n    // read row_sum and final row_max here\n    Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));\n    copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);\n\n    pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);\n    ++pipeline_s0_c_consumer_state;\n\n    pipeline_o.consumer_wait(pipeline_o_consumer_state);\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n    // store to epi smem\n\n    // loop:\n    //    TMEM_LOAD\n    //    FMUL2 scale = 1 / global_sum * out_quant_scale\n    //    F2FP\n    //    store to smem\n    Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});\n    Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);\n    correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);\n\n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    cutlass::arch::fence_view_async_tmem_load();\n\n    pipeline_o.consumer_release(pipeline_o_consumer_state);\n    ++pipeline_o_consumer_state;\n\n    pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n\n    pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);\n\n    // load from V1\n    copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);\n\n    
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);\n    ++pipeline_s1_c_consumer_state;\n\n    pipeline_o.consumer_wait(pipeline_o_consumer_state);\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n\n    correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);\n\n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});\n\n      ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    cutlass::arch::fence_view_async_tmem_load();\n\n    pipeline_o.consumer_release(pipeline_o_consumer_state);\n    ++pipeline_o_consumer_state;\n\n    pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n  }\n\n\n  template<\n    class BlkCoord, class ProblemShape, class ParamsProblemShape,\n    class TensorStorageEpi, class CollectiveEpilogue\n  >\n  CUTLASS_DEVICE auto\n  correction_empty(\n      BlkCoord const& blk_coord,\n      Params const& params, ProblemShape const& problem_shape,\n      ParamsProblemShape const& params_problem_shape,\n      TensorStorageEpi& shared_storage_epi,\n      PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,\n      CollectiveEpilogue& epilogue) {\n\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n\n    Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});\n    Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), 
select<0,3>(problem_shape), epilogue.params.dLSE);\n    float lse = -INFINITY;\n    int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);\n\n#if 1\n\n    using ElementOut = typename CollectiveEpilogue::ElementOut;\n    auto tiled_copy = make_cotiled_copy(\n        Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},\n        make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),\n        sO.layout());\n\n    auto thr_copy = tiled_copy.get_slice(thread_idx);\n    auto tOgO = thr_copy.partition_D(sO);\n    auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));\n    clear(tOrO);\n    \n    copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));\n#endif\n    \n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n\n    copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));\n    cutlass::arch::fence_view_async_shared();\n    pipeline_epi.producer_acquire(pipeline_epi_producer_state);\n\n    if (epilogue.params.ptr_LSE != nullptr) {\n      int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});\n\n      int row_offset = 0;\n      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];\n      }\n\n      if (row_idx < get<0>(problem_shape)) {\n        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;\n      }\n    }\n\n    cutlass::arch::fence_view_async_shared();\n    
pipeline_epi.producer_commit(pipeline_epi_producer_state);\n    ++pipeline_epi_producer_state;\n  }\n\n};\n\n}  // namespace cutlass::fmha::collective\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/arch/memory_sm80.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n#include \"cute/tensor.hpp\"\n#include \"cute/layout.hpp\"\n\n#include \"../collective/fmha_common.hpp\"\n#include \"../collective/fmha_fusion.hpp\"\n\nnamespace cutlass::fmha::collective {\n\nusing namespace cute;\n\ntemplate<\n  class Element,\n  class StrideQ,\n  class StrideK,\n  class StrideV,\n  class CollectiveMmaQK,\n  class CollectiveMmaPV,\n  class SmemLayoutQ,\n  class SmemLayoutK,\n  class SmemLayoutV,\n  class TensorStorage,\n  class PipelineQ,\n  class PipelineKV,\n  class Mask,\n  class TileShape,\n  class OrderLoadEpilogue = cute::false_type\n>\nstruct Sm100MlaFwdLoadTmaWarpspecialized {\n\n  using TileShapeQK = typename CollectiveMmaQK::TileShape;\n  using TileShapePV = typename CollectiveMmaPV::TileShape;\n\n  static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);\n  static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);\n\n  static const int NumWarpsEpilogue = 1;\n  static const int NumWarpsLoad = 1;\n\n  struct Arguments {\n    const Element* ptr_Q;\n    StrideQ dQ;\n    const Element* ptr_K;\n    StrideK 
dK;\n    const Element* ptr_V;\n    StrideV dV;\n  };\n\n  using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;\n  using TMA_K = typename CollectiveMmaQK::Params::TMA_B;\n  using TMA_V = typename CollectiveMmaPV::Params::TMA_B;\n\n  struct Params {\n    TMA_Q tma_load_q;\n    TMA_K tma_load_k;\n    TMA_V tma_load_v;\n  };\n\n  template<class ProblemShape>\n  static Params to_underlying_arguments(\n      ProblemShape const& problem_shape,\n      Arguments const& args,\n      void* workspace) {\n\n    auto ptr_Q = args.ptr_Q;\n    auto ptr_K = args.ptr_K;\n    auto ptr_V = args.ptr_V;\n    auto dQ = args.dQ;\n    auto dK = args.dK;\n    auto dV = args.dV;\n\n    using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;\n\n    IntProblemShape problem_shape_qk;\n    if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {\n      auto cumulative_length_q = get<0>(problem_shape).cumulative_length;\n      auto cumulative_length_k = get<1>(problem_shape).cumulative_length;\n      if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {\n          get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;\n          get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;\n          get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);\n          get<3>(problem_shape_qk) = get<3>(problem_shape);\n      }\n    } else {\n      problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;\n    }\n\n    get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));\n    get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));\n\n    auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));\n\n    auto params_qk = CollectiveMmaQK::to_underlying_arguments(\n        problem_shape_qk,\n        typename CollectiveMmaQK::Arguments {\n            ptr_Q, dQ,\n            ptr_K, dK,\n        }, 
/*workspace=*/ nullptr);\n\n    auto params_pv = CollectiveMmaPV::to_underlying_arguments(\n        problem_shape_pv,\n        typename CollectiveMmaPV::Arguments {\n            ptr_K, dK,  // never used, dummy\n            ptr_V, select<1,0,2>(dV),\n        }, /*workspace=*/ nullptr);\n\n    return Params{\n        params_qk.tma_load_a,\n        params_qk.tma_load_b,\n        params_pv.tma_load_b\n    };\n  }\n\n\n  CUTLASS_DEVICE\n  static void prefetch_tma_descriptors(Params const& params) {\n    cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());\n    cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());\n    cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());\n  }\n\n  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>\n  CUTLASS_DEVICE void\n  load(\n      BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,\n      Params const& params, ParamsProblemShape const& params_problem_shape,\n      TensorStorage& storage,\n      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,\n      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {\n\n    BlkCoord blk_coord_q = blk_coord_in;\n    BlkCoord blk_coord_kv = blk_coord_in;\n\n    auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));\n    auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));\n\n    int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);\n\n    using X = Underscore;\n\n    // this one is only executed by one thread, no need to elect_one\n\n    // Q1, K1, Q2, V1, K2, V2, K3, V3, ...\n    // two pipes: Q and KV\n    // from Memory (prod) to TensorCore (cons)\n\n    // compute gQ, sQ\n    // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1\n    ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);\n    Tensor mQ_qdl_p = 
params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));\n\n    int q_offs_0 = 0;\n\n    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {\n      auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;\n      if (cumulative_length_q != nullptr) {\n        q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];\n        get<2,1>(blk_coord_q) = 0;\n      }\n    }\n\n    Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);\n\n    Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});\n    Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);\n    Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});\n    auto [tQgQ_qdl, tQsQ] = tma_partition(\n      params.tma_load_q, _0{}, make_layout(_1{}), \n      group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)\n    );\n    Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));\n\n    // compute gK, sK\n    Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));\n\n    int kv_offs_0 = 0;\n\n    if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {\n      auto cumulative_length = get<1>(params_problem_shape).cumulative_length;\n      if (cumulative_length != nullptr) {\n        kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];\n        get<2,1>(blk_coord_kv) = 0;\n      }\n    }\n\n    Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);\n\n    Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});\n    Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);\n    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});\n    auto [tKgK_kdl, tKsK] = tma_partition(\n      params.tma_load_k, _0{}, make_layout(_1{}),\n      group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)\n    );\n    Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));\n\n  
  // compute gV, sV\n    ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);\n    Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));\n\n    Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);\n\n    Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});\n    Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);\n    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});\n    auto [tVgV_dkl, tVsV] = tma_partition(\n      params.tma_load_v, _0{}, make_layout(_1{}),\n      group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)\n    );\n    auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));\n\n    // blk_coord is decomposed in terms of TileShape, not TileShapeQK\n    // As such, it needs to be transformed as\n    // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)\n    //          b -> 2*a (Ki i even) 2*a+1 (Ki i odd)\n\n    uint32_t lane_predicate = cute::elect_one_sync();\n\n    // Q1\n    int q0_index = 2 * get<0>(blk_coord_q);\n    int q1_index = 2 * get<0>(blk_coord_q) + 1;\n    pipeline_q.producer_acquire(pipeline_q_producer_state);\n    if (lane_predicate) {\n      auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);\n      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));\n    }\n    ++pipeline_q_producer_state;\n\n    // K1\n    int k_index = 0;\n    pipeline_kv.producer_acquire(pipeline_kv_producer_state);\n    if (lane_predicate) {\n      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n      copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));\n    }\n    ++pipeline_kv_producer_state;\n\n    // Q2\n    pipeline_q.producer_acquire(pipeline_q_producer_state);\n    if (lane_predicate) {\n      auto tma_barrier = 
pipeline_q.producer_get_barrier(pipeline_q_producer_state);\n      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));\n    }\n    ++pipeline_q_producer_state;\n\n    if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {\n      cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, \n                                        cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n    }\n\n    // V1\n    pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);\n    if (lane_predicate) {\n      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n      copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));\n    }\n    ++pipeline_kv_producer_state;\n    k_index += 1;\n\n    // loop:\n    mask_tile_count -= 1;\n    for (; mask_tile_count > 0; mask_tile_count -= 1) {\n\n      // Ki\n      pipeline_kv.producer_acquire(pipeline_kv_producer_state);\n      if (lane_predicate) {\n        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n        copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));\n\n        // prefetch vi\n        cute::prefetch(params.tma_load_v, tVgV(_, k_index));\n      }\n      ++pipeline_kv_producer_state;\n\n      // Vi\n      pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);\n      if (lane_predicate) {\n        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);\n        copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));\n\n        // prefetch ki+1\n        if(mask_tile_count > 1) {\n          cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));\n        }\n      }\n      ++pipeline_kv_producer_state;\n      
k_index += 1;\n    }\n  }\n};\n\n}  // namespace cutlass::fmha::collective\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/common/gather_tensor.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cute/layout.hpp\"\n#include \"cute/tensor.hpp\"\n#include \"cute/util/print.hpp\"\n\nnamespace example {\n\nusing namespace cute;\n\n// Empty type used to disable gather/scatter for a GEMM argument\nstruct NoGather\n{\n  template<class... Ts>\n  NoGather(Ts...) {}\n};\n\n/// Function object that applies an index to its argument\ntemplate <class Index>\nstruct IndexedGather\n{\n  CUTE_HOST_DEVICE constexpr\n  IndexedGather(Index const *indices = {}): indices_(indices) {}\n\n  template <typename I>\n  CUTE_HOST_DEVICE constexpr\n  Index\n  operator()(I i) const { return indices_[i]; }\n\n  CUTE_HOST_DEVICE friend\n  void\n  print(IndexedGather const &s) {\n    cute::print(\"Indexed\");\n  }\n\n  Index const *indices_;\n};\n\n/// Function object that applies a stride to its argument\n/// Example: StridedGather<_2> gathers every other row/column\ntemplate <class Stride>\nstruct StridedGather\n{\n  CUTE_HOST_DEVICE constexpr\n  StridedGather(Stride stride = {}): stride_(stride) {}\n\n  template <class I>\n  CUTE_HOST_DEVICE constexpr\n  auto\n  operator()(I i) const { return i * stride_; }\n\n  CUTE_HOST_DEVICE friend\n  void\n  print(StridedGather const &s) {\n    cute::print(\"Strided{\");\n    print(s.stride_);\n    cute::print(\"}\");\n  }\n\n  Stride stride_;\n};\n\n/// Custom stride object that 
applies a function followed by a stride\ntemplate <class Func, class Stride>\nstruct CustomStride\n{\n  CUTE_HOST_DEVICE constexpr\n  CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {}\n\n  template <class I>\n  CUTE_HOST_DEVICE constexpr friend\n  auto\n  operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; }\n\n  template <class I>\n  CUTE_HOST_DEVICE constexpr friend\n  auto\n  operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; }\n\n  CUTE_HOST_DEVICE friend\n  void\n  print(CustomStride const & s) {\n    cute::print(\"Custom{\");\n    print(s.func_);\n    cute::print(\",\");\n    print(s.stride_);\n    cute::print(\"}\");\n  }\n\n  template<class Div>\n  CUTE_HOST_DEVICE constexpr friend\n  auto\n  safe_div(CustomStride const &s, Div const &div)\n  {\n    return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));\n  }\n\n  // Circumvent the requirement on make_layout that shape and stride are integral\n  template <class Shape>\n  CUTE_HOST_DEVICE constexpr friend\n  auto\n  make_layout(Shape const &shape, CustomStride const &stride)\n  {\n    return Layout<Shape, CustomStride>(shape, stride);\n  }\n\n  Func func_;\n  Stride stride_;\n};\n\ntemplate<class Stride, class Func>\nCUTLASS_HOST_DEVICE\nauto\nmake_custom_stride_layout(Stride const &stride, Func&& func)\n{\n  // Use a dummy shape and replace the first non-unit stride with a custom gather stride\n  auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; });\n  constexpr int I = decltype(idx)::value;\n  return make_layout(repeat_like(stride, _1{}),\n                     replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));\n}\n\n/// Helper function to optionally create a gather tensor\ntemplate<class Iterator, class Shape, class Stride, class Func>\nCUTLASS_HOST_DEVICE\nauto\nmake_gather_tensor(Iterator iter, Shape const &shape, Stride const 
&stride, Func &&func)\n{\n  if constexpr (not cutlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value) {\n    Layout matrix_layout = make_identity_layout(shape);\n    auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));\n    Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));\n    return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});\n  } else {\n    return make_tensor(iter, shape, stride);\n  }\n}\n\n} // namespace example\n\nnamespace cute\n{\n\ntemplate<int N, int I, class Shape, class Stride>\nCUTE_HOST_DEVICE constexpr\nauto\nupcast(Shape const& shape, Stride const& stride)\n{\n  if constexpr (is_tuple<Shape>::value) {\n    return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N,I>(s,d); });\n  } else if constexpr (is_scaled_basis<Stride>::value) {\n    if constexpr (Stride::mode() == I) {\n      return make_layout(ceil_div(shape, Int<N>{}), ceil_div(stride, Int<N>{}));\n    } else {\n      return make_layout(shape, stride);\n    }\n  } else {\n    return upcast<N>(shape, stride);\n  }\n\n  CUTE_GCC_UNREACHABLE;\n}\n\ntemplate <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>\nCUTE_HOST_DEVICE constexpr\nauto\nupcast(ComposedLayout<Layout<OuterShape,OuterStride>,Offset,Layout<Shape,Stride>> const& layout)\n{\n  // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset\n  auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; });\n  constexpr int I = decltype(idx)::value;\n\n  // Upcast the outer layout (works as expected)\n  auto outer = upcast<N>(layout.layout_a());\n\n  // Upcast the accumulated offset along stride-1 mode\n  auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));\n\n  // Upcast the inner layout's shape along stride-1 mode\n  auto inner = 
upcast<N,I>(layout.layout_b().shape(), layout.layout_b().stride());\n\n  return composition(outer, offset, inner);\n}\n\n} // namespace cute\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/common/helper.h",
    "content": "/***************************************************************************************************\n * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n #pragma once\n\n #include \"cuda_runtime.h\"\n #include <iostream>\n \n /**\n  * Panic wrapper for unwinding CUTLASS errors\n  */\n #define CUTLASS_CHECK(status)                                                                    \\\n   {                                                                                              \\\n     cutlass::Status error = status;                                                              \\\n     if (error != cutlass::Status::kSuccess) {                                                    \\\n       std::cerr << \"Got cutlass error: \" << cutlassGetStatusString(error) << \" at: \" << __LINE__ \\\n                 << std::endl;                                                                    \\\n       exit(EXIT_FAILURE);                                                                        \\\n     }                                                                                            \\\n   }\n \n \n /**\n  * Panic wrapper for unwinding CUDA runtime errors\n  */\n #define CUDA_CHECK(status)                                              \\\n   {                                                                     \\\n     cudaError_t error = status;                                         \\\n     if (error != cudaSuccess) {                                         \\\n       std::cerr 
<< \"Got bad cuda status: \" << cudaGetErrorString(error) \\\n                 << \" at line: \" << __LINE__ << std::endl;               \\\n       exit(EXIT_FAILURE);                                               \\\n     }                                                                   \\\n   }\n \n   \n#define FLASH_MLA_ASSERT(cond) \\\ndo { \\\n  if (!(cond)) { \\\n    std::cerr << \"FLASH_MLA_ASSERT: \" << #cond << \" failed at \" << __FILE__ << \":\" << __LINE__ << std::endl; \\\n    std::abort(); \\\n  } \\\n} while (0)\n\n "
  },
  {
    "path": "csrc/sm100/prefill/dense/common/mask.cuh",
    "content": "#pragma once\n\nenum class MaskMode {\n  kNone = 0U,    // No mask\n  kCausal = 1U,  // Causal mask\n  kCustom = 2U,  // Custom mask\n};\n\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/common/pipeline_mla.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n/*!\n  \\file\n  \\brief Support the producer to acquire specific bytes of data.\n*/\n\n#pragma once\n\n#include \"cutlass/pipeline/sm100_pipeline.hpp\"\n\nnamespace cutlass {\n\nusing namespace cute;\n\ntemplate <\n  int Stages_,\n  class ClusterShape = Shape<int,int,_1>,\n  class AtomThrShape_MNK_ = Shape<_1,_1,_1>\n>\nclass PipelineTmaAsyncMla {\n\npublic:\n  static constexpr uint32_t Stages = Stages_;\n  using AtomThrShape_MNK = AtomThrShape_MNK_;\n\nprivate:\n  using Impl = PipelineTmaUmmaAsync<Stages_, ClusterShape, AtomThrShape_MNK_>;\n\npublic:\n  using FullBarrier  = typename Impl::FullBarrier;\n  using EmptyBarrier = typename Impl::EmptyBarrier;\n  using ProducerBarrierType = typename Impl::ProducerBarrierType;\n  using ConsumerBarrierType = typename Impl::ConsumerBarrierType;\n  using PipelineState = typename Impl::PipelineState;\n  using SharedStorage = typename Impl::SharedStorage;\n  using ThreadCategory = typename Impl::ThreadCategory;\n  using Params = typename Impl::Params;\n\n\n  using McastDirection = McastDirection;\n\n  // Helper function to initialize barriers\n  static\n  CUTLASS_DEVICE\n  void\n  init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {\n    int warp_idx = canonical_warp_idx_sync();\n    if (warp_idx == params.initializing_warp) {\n      // Barrier FULL and EMPTY 
init\n      constexpr int producer_arv_cnt = 1;\n      auto atom_thr_shape = AtomThrShape_MNK{};\n      uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +\n                                     (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1;\n\n      cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(\n          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);\n    }\n    cutlass::arch::fence_barrier_init();\n  }\n\n  static\n  CUTLASS_DEVICE\n  void\n  init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {\n    auto atom_thr_shape = AtomThrShape_MNK{};\n\n    int warp_idx = canonical_warp_idx_sync();\n    if (warp_idx == params.initializing_warp) {\n      // Barrier FULL and EMPTY init\n      constexpr int producer_arv_cnt = 1;\n      uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ?\n        cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas\n        cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape);  // Mcast with col ctas\n\n      cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(\n          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);\n    }\n    cutlass::arch::fence_barrier_init();\n  }\n\n  CUTLASS_DEVICE\n  void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {\n    // Calculate consumer mask\n    if (params_.role == ThreadCategory::Consumer) {\n      auto cluster_layout = make_layout(cluster_shape);\n      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(cluster_shape, 
AtomThrShape_MNK{}, block_id_in_cluster);\n    }\n  }\n\n  CUTLASS_DEVICE\n  void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {\n    // Calculate consumer mask\n    dim3 block_id_in_cluster = cute::block_id_in_cluster();\n    auto cluster_layout = make_layout(cluster_shape);\n    if (mcast_direction == McastDirection::kRow) {\n      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);\n    }\n    else {\n      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);\n    }\n  }\n\n\npublic:\n  template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>\n  CUTLASS_DEVICE\n  PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})\n      : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})\n      , params_(params)\n      , empty_barrier_ptr_(&storage.empty_barrier_[0])\n      , full_barrier_ptr_(&storage.full_barrier_[0]) {\n        static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);\n        if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {\n          init_barriers(storage, params_, cluster_shape);\n        }\n\n        static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);\n        if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {\n          init_masks(cluster_shape);\n        }\n  }\n\n  template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>\n  CUTLASS_DEVICE\n  PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})\n      : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})\n      
, params_(params)\n      , empty_barrier_ptr_(&storage.empty_barrier_[0])\n      , full_barrier_ptr_(&storage.full_barrier_[0]) {\n    static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);\n    if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {\n      init_barriers(storage, params_, cluster_shape, mcast_direction);\n    }\n\n    static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);\n    if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {\n      init_masks(cluster_shape, mcast_direction);\n    }\n  }\n\n\n  CUTLASS_DEVICE\n  void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {\n    impl_.producer_acquire(state, barrier_token);\n  }\n\n  CUTLASS_DEVICE\n  void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {\n    detail::pipeline_check_is_producer(params_.role);\n    if (barrier_token != BarrierStatus::WaitDone) {\n      empty_barrier_ptr_[stage].wait(phase);\n    }\n\n    if (params_.is_leader) {\n      full_barrier_ptr_[stage].arrive_and_expect_tx(bytes);\n    }\n    #ifndef NDEBUG\n    if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) {\n      asm volatile (\"brkpt;\\n\" ::);\n    }\n\n    // Most likely you have elected more than one leader\n    if (params_.is_leader && (threadIdx.x % 32 != 0)) {\n      asm volatile (\"brkpt;\\n\" ::);\n    }\n    #endif\n  }\n\n  CUTLASS_DEVICE\n  void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {\n    producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token);\n  }\n\n  CUTLASS_DEVICE\n  ProducerBarrierType* producer_get_barrier(PipelineState state) {\n    return impl_.producer_get_barrier(state);\n  }\n\n  CUTLASS_DEVICE\n  void consumer_wait(PipelineState 
state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {\n    impl_.consumer_wait(state, barrier_token);\n  }\n\n  CUTLASS_DEVICE\n  void consumer_release(PipelineState state) {\n    consumer_release(state.index(), false);\n  }\n\nprivate:\n  Impl impl_;\n  Params params_;\n  EmptyBarrier *empty_barrier_ptr_;\n  FullBarrier *full_barrier_ptr_;\n  uint16_t block_id_mask_ = 0;\n  static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;\n\n  // Consumer signalling Producer of completion\n  // Ensures all blocks in the same row and column get notified.\n  CUTLASS_DEVICE\n  void consumer_release(uint32_t stage, uint32_t skip) {\n    detail::pipeline_check_is_consumer(params_.role);\n    uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);\n    if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1\n      if (!skip) {\n        cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);\n      }\n    }\n    else {\n      if (!skip) {\n        if constexpr (cute::is_static_v<ClusterShape> and size(ClusterShape{}) == 1) {\n          cutlass::arch::umma_arrive(smem_ptr);\n        }\n        else {\n          cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);\n        }\n      }\n    }\n  }\n};\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/common/pow_2.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n#pragma once\n\n#include <cute/config.hpp>\n#include <cute/numeric/integral_constant.hpp>\n\n#include <cuda_runtime.h>\n\nnamespace cutlass::fmha {\n\nstruct Pow2 {                                                                   \n  int n;                                                                        \n  int log2_n;                                                                   \n                                                                                \n  explicit CUTE_DEVICE Pow2(int n) : n(n) {\n#ifdef __CUDA_ARCH__\n    log2_n = __ffs(n) - 1;\n#endif\n  }                    \n                                                                                \n  template<class T>  \n  CUTE_HOST_DEVICE T operator *(T const& b) const {\n    return n * b;\n  }\n\n  template<int N>\n  CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {\n    if constexpr (N & (N - 1) == 0) {\n      return Pow2{n * N};\n    }\n    return n * N;\n  }\n\n};                                                                              \n\ntemplate<class T>\nCUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {\n  return a >> b.log2_n;\n}\n\ntemplate<class T>\nCUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {\n  return a & (b.n - 1);\n}\n\ntemplate<class T>\nCUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) 
{\n  return a < b.n;\n}\n\nCUTE_HOST_DEVICE void print(Pow2 const& a) {\n  printf(\"2^%d\", a.log2_n);\n}\n\n} // end namespace cutlass::fmha\n\nnamespace cute {\n\ntemplate <>\nstruct is_integral<cutlass::fmha::Pow2> : true_type {};\n\n} // end namespace cute\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/common/utils.hpp",
    "content": "#pragma once\n\n#include <torch/extension.h>\n#include \"cutlass/numeric_types.h\"\n#include \"helper.h\"\n\ntemplate <typename T>\nstruct cutlass_dtype {\n  using type = T;\n};\n\ntemplate <>\nstruct cutlass_dtype<half> {\n  using type = cutlass::half_t;\n};\n\ntemplate <>\nstruct cutlass_dtype<nv_bfloat16> {\n  using type = cutlass::bfloat16_t;\n};\n\ntemplate <>\nstruct cutlass_dtype<__nv_fp8_e4m3> {\n  using type = cutlass::float_e4m3_t;\n};\n\ntemplate <>\nstruct cutlass_dtype<__nv_fp8_e5m2> {\n  using type = cutlass::float_e5m2_t;\n};\n\ntemplate <typename T>\nusing cutlass_dtype_t = typename cutlass_dtype<T>::type;"
  },
  {
    "path": "csrc/sm100/prefill/dense/device/fmha.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n/*!\n  \\file\n  \\brief A universal device layer for CUTLASS 3.x-style kernels.\n*/\n\n#pragma once\n\n// common\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/device_kernel.h\"\n\n#if !defined(__CUDACC_RTC__)\n#include \"cutlass/cluster_launch.hpp\"\n#include \"cutlass/trace.h\"\n#endif // !defined(__CUDACC_RTC__)\n\n////////////////////////////////////////////////////////////////////////////////\n\nnamespace cutlass::fmha::device {\n\n////////////////////////////////////////////////////////////////////////////////\n////////////////////////////// CUTLASS 3.x API /////////////////////////////////\n////////////////////////////////////////////////////////////////////////////////\n\ntemplate <class Kernel_>\nclass FMHA {\npublic:\n  using Kernel = Kernel_;\n\n  static int const kThreadCount = Kernel::MaxThreadsPerBlock;\n\n  /// Argument structure: User API\n  using Arguments = typename Kernel::Arguments;\n  /// Argument structure: Kernel API\n  using Params = typename Kernel::Params;\n\nprivate:\n\n  /// Kernel API parameters object\n  Params params_;\n\n  bool is_initialized(bool set = false) {\n    static bool initialized = false;\n    if (set) initialized = true;\n    return initialized;\n  }\n\npublic:\n\n  /// Access the Params structure\n  Params const& params() const {\n    return params_;\n  }\n\n  /// Determines 
whether the GEMM can execute the given problem.\n  static Status\n  can_implement(Arguments const& args) {\n    if (Kernel::can_implement(args)) {\n      return Status::kSuccess;\n    }\n    else {\n      return Status::kInvalid;\n    }\n  }\n\n  /// Gets the workspace size\n  static size_t\n  get_workspace_size(Arguments const& args) {\n    size_t workspace_bytes = 0;\n    workspace_bytes += Kernel::get_workspace_size(args);\n    return workspace_bytes;\n  }\n\n  /// Computes the grid shape\n  static dim3\n  get_grid_shape(Params const& params) {\n    return Kernel::get_grid_shape(params);\n  }\n\n  /// Computes the maximum number of active blocks per multiprocessor\n  static int maximum_active_blocks(int /* smem_capacity */ = -1) {\n    CUTLASS_TRACE_HOST(\"FMHA::maximum_active_blocks()\");\n    int max_active_blocks = -1;\n    int smem_size = Kernel::SharedStorageSize;\n\n    // first, account for dynamic smem capacity if needed\n    cudaError_t result;\n    if (smem_size >= (48 << 10)) {\n      CUTLASS_TRACE_HOST(\"  Setting smem size to \" << smem_size);\n      result = cudaFuncSetAttribute(\n          device_kernel<Kernel>,\n          cudaFuncAttributeMaxDynamicSharedMemorySize,\n          smem_size);\n      if (cudaSuccess != result) {\n        result = cudaGetLastError(); // to clear the error bit\n        CUTLASS_TRACE_HOST(\n          \"  cudaFuncSetAttribute() returned error: \"\n          << cudaGetErrorString(result));\n        return -1;\n      }\n    }\n\n    // query occupancy after setting smem size\n    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks,\n        device_kernel<Kernel>,\n        Kernel::MaxThreadsPerBlock,\n        smem_size);\n\n    if (cudaSuccess != result) {\n      result = cudaGetLastError(); // to clear the error bit\n      CUTLASS_TRACE_HOST(\n        \"  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: \"\n        << cudaGetErrorString(result));\n      return -1;\n    }\n\n    
CUTLASS_TRACE_HOST(\"  max_active_blocks: \" << max_active_blocks);\n    return max_active_blocks;\n  }\n\n  /// Initializes GEMM state from arguments.\n  Status\n  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {\n    CUTLASS_TRACE_HOST(\"FMHA::initialize() - workspace \"\n      << workspace << \", stream: \" << (stream ? \"non-null\" : \"null\"));\n\n    // Initialize the workspace\n    Status status = Kernel::initialize_workspace(args, workspace, stream);\n    if (status != Status::kSuccess) {\n      return status;\n    }\n\n    // Initialize the Params structure\n    params_ = Kernel::to_underlying_arguments(args, workspace);\n\n    if (is_initialized()) return Status::kSuccess;\n\n    // account for dynamic smem capacity if needed\n    int smem_size = Kernel::SharedStorageSize;\n    if (smem_size >= (48 << 10)) {\n      CUTLASS_TRACE_HOST(\"  Setting smem size to \" << smem_size);\n      cudaError_t result = cudaFuncSetAttribute(\n          device_kernel<Kernel>,\n          cudaFuncAttributeMaxDynamicSharedMemorySize,\n          smem_size);\n      if (cudaSuccess != result) {\n        result = cudaGetLastError(); // to clear the error bit\n        CUTLASS_TRACE_HOST(\"  cudaFuncSetAttribute() returned error: \" << cudaGetErrorString(result));\n        return Status::kErrorInternal;\n      }\n    }\n\n    is_initialized(true);\n\n    return Status::kSuccess;\n  }\n\n  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.\n  Status\n  update(Arguments const& args, void* workspace = nullptr) {\n    CUTLASS_TRACE_HOST(\"FMHA()::update() - workspace: \" << workspace);\n\n    size_t workspace_bytes = get_workspace_size(args);\n    if (workspace_bytes > 0 && nullptr == workspace) {\n      return Status::kErrorWorkspaceNull;\n    }\n\n    params_ = Kernel::to_underlying_arguments(args, workspace);\n    return Status::kSuccess;\n  }\n\n  /// Primary run() entry point API that is static 
allowing users to create and manage their own params.\n  /// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()\n  static Status\n  run(Params& params, cudaStream_t stream = nullptr) {\n    CUTLASS_TRACE_HOST(\"FMHA::run()\");\n    dim3 const block = Kernel::get_block_shape();\n    dim3 const grid = get_grid_shape(params);\n\n    // No need to launch the kernel\n    if(grid.x == 0 || grid.y == 0 || grid.z == 0) {\n      return Status::kSuccess; \n    }\n\n    // configure smem size and carveout\n    int smem_size = Kernel::SharedStorageSize;\n\n    Status launch_result;\n    // Use extended launch API only for mainloops that use it\n    if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {\n      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),\n                   cute::size<1>(typename Kernel::ClusterShape{}),\n                   cute::size<2>(typename Kernel::ClusterShape{}));\n      void const* kernel = (void const*) device_kernel<Kernel>;\n      void* kernel_params[] = {&params};\n      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);\n    }\n    else {\n      launch_result = Status::kSuccess;\n      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);\n    }\n\n    cudaError_t result = cudaGetLastError();\n    if (cudaSuccess == result && Status::kSuccess == launch_result) {\n      return Status::kSuccess;\n    }\n    else {\n      CUTLASS_TRACE_HOST(\"  Kernel launch failed. 
Reason: \" << result);\n      return Status::kErrorInternal;\n    }\n  }\n\n  //\n  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.\n  //\n\n  /// Launches the kernel after first constructing Params internal state from supplied arguments.\n  Status\n  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {\n    Status status = initialize(args, workspace, stream);\n    if (Status::kSuccess == status) {\n      status = run(params_, stream);\n    }\n    return status;\n  }\n\n  /// Launches the kernel after first constructing Params internal state from supplied arguments.\n  Status\n  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {\n    return run(args, workspace, stream);\n  }\n\n  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.\n  Status\n  run(cudaStream_t stream = nullptr) {\n    return run(params_, stream);\n  }\n\n  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.\n  Status\n  operator()(cudaStream_t stream = nullptr) {\n    return run(params_, stream);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////\n\n} // namespace cutlass::device\n\n////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n// common\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/kernel_hardware_info.hpp\"\n#include \"cute/tensor.hpp\"\n\n#include \"../device/fmha.hpp\"\n#include \"../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp\"\n#include \"../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp\"\n#include \"../kernel/fmha_kernel_bwd_sum_OdO.hpp\"\n#include \"../kernel/fmha_kernel_bwd_convert.hpp\"\n\n////////////////////////////////////////////////////////////////////////////////\n\nnamespace cutlass::fmha::device {\n\n////////////////////////////////////////////////////////////////////////////////\n////////////////////////////// CUTLASS 3.x API /////////////////////////////////\n////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    class ProblemShape,\n    class Element,\n    class ElementAccumulator,\n    class TileShape,\n    bool IsMla,\n    class Mask\n>\nclass Sm100FmhaBwd {\npublic:\n  /// Argument structure: User API\n  struct Arguments {\n    // Q K D D_VO HB\n    ProblemShape problem_shape;\n\n    const Element* ptr_Q;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;\n    const Element* ptr_K;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;\n    const Element* ptr_V;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> 
stride_V;\n\n    const Element* ptr_O;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;\n    const ElementAccumulator* ptr_LSE;\n    cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;\n\n    const Element* ptr_dO;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;\n\n    Element* ptr_dQ;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;\n    Element* ptr_dK;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;\n    Element* ptr_dV;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;\n\n    ElementAccumulator softmax_scale;\n\n    cutlass::KernelHardwareInfo hw_info;\n  };\n\n  using OperationSumOdO = cutlass::fmha::device::FMHA<\n    cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>\n  >;\n  using OperationConvert = cutlass::fmha::device::FMHA<\n    cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>\n  >;\n\n  using OperationMha= cutlass::fmha::device::FMHA<\n      cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<\n          ProblemShape, Element, ElementAccumulator, TileShape, Mask\n      >\n  >;\n\n  using OperationMla = cutlass::fmha::device::FMHA<\n      cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<\n          ProblemShape, Element, ElementAccumulator, TileShape, Mask\n      >\n  >;\n\n  using Operation = std::conditional_t<IsMla, OperationMla, OperationMha>;\n\n  using Kernel = typename Operation::Kernel;\n\n  struct Params {\n    OperationSumOdO op_sum_OdO;\n    Operation op;\n    OperationConvert op_convert;\n    ElementAccumulator* dQ_acc;\n    size_t dQ_acc_size;\n  };\n\nprivate:\n  Params params_;\n\n  static typename OperationSumOdO::Arguments to_sum_OdO_arguments(\n        Arguments const& args,\n        ElementAccumulator* sum_odo = nullptr,\n        ElementAccumulator* scaled_lse = nullptr) {\n    using namespace cute;\n    auto [Q_, K, D, D_VO, HB] = args.problem_shape;\n 
   auto [H, B] = HB;\n    D = cutlass::round_up(D, 8);  // Alignment\n    int Q = cutlass::round_up(static_cast<int>(Q_), 8);  // Alignment\n    auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));\n    auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));\n    auto log2_e = log2f(expf(1.0f));\n    return typename OperationSumOdO::Arguments {\n      args.problem_shape,\n      args.ptr_O, args.stride_O,\n      args.ptr_dO, args.stride_dO,\n      sum_odo, stride_sum_OdO,\n      args.ptr_LSE, args.stride_LSE,\n      scaled_lse, stride_scaled_lse,\n      -1.0f, -log2_e\n    };\n  }\n\n  static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {\n    using namespace cute;\n    auto [Q_, K, D, D_VO, HB] = args.problem_shape;\n    auto [H, B] = HB;\n    D = cutlass::round_up(D, 8);  // Alignment\n    int Q = cutlass::round_up(static_cast<int>(Q_), 8);  // Alignment\n    auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));\n    return typename OperationConvert::Arguments {\n      args.problem_shape,\n      src, stride_src_dQ,\n      nullptr, stride_src_dQ,\n      nullptr, stride_src_dQ,\n      args.ptr_dQ, args.stride_dQ,\n      nullptr, args.stride_dK,\n      nullptr, args.stride_dV,\n      args.softmax_scale\n    };\n  }\n\n  static typename Operation::Arguments to_bwd_arguments(\n      Arguments const& args,\n      ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},\n      ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},\n      ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {\n\n    return typename Operation::Arguments{\n      args.problem_shape,\n      { args.ptr_Q,  args.stride_Q,\n        args.ptr_K,  args.stride_K,\n        args.ptr_V,  args.stride_V,\n        args.ptr_dO, 
args.stride_dO,\n        scaled_lse, stride_scaled_lse,\n        sum_OdO, stride_sum_OdO,\n        dQ_acc, stride_dQ,\n        args.softmax_scale },\n      { args.ptr_dK, args.stride_dK,\n        args.ptr_dV, args.stride_dV },\n      args.hw_info\n    };\n  }\n\npublic:\n\n  /// Determines whether the GEMM can execute the given problem.\n  static Status\n  can_implement(Arguments const& args) {\n    Status status = Status::kSuccess;\n\n    status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));\n    if (status != Status::kSuccess) {\n      return status;\n    }\n\n    status = OperationConvert::can_implement(to_convert_arguments(args));\n    if (status != Status::kSuccess) {\n      return status;\n    }\n\n    status = Operation::can_implement(to_bwd_arguments(args));\n    if (status != Status::kSuccess) {\n      return status;\n    }\n\n    return status;\n  }\n\n  /// Gets the workspace size\n  static size_t\n  get_workspace_size(Arguments const& args) {\n    auto [Q_, K, D, D_VO, HB] = args.problem_shape;\n    auto [H, B] = HB;\n    D = cutlass::round_up(D, 8);  // Alignment\n    int Q = cutlass::round_up(static_cast<int>(Q_), 8);  // Alignment\n    size_t workspace_bytes = 0;\n    // OdO vector\n    workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;\n    // scaled LSE vector\n    workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;\n    // FP32 versions of outputs that are churned (start off with Q only)\n    workspace_bytes += sizeof(ElementAccumulator) * B*H*Q*D;\n    return workspace_bytes;\n  }\n\n  /// Initializes state from arguments.\n  Status\n  initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) {\n    CUTLASS_TRACE_HOST(\"Universal::initialize_split() - workspace_dQ=\"\n      << workspace_dQ << \", workspace_sum_OdO=\" << workspace_sum_OdO << \", stream: \" << (stream ? 
\"non-null\" : \"null\"));\n\n    auto [Q_, K, D, D_VO, HB] = args.problem_shape;\n    auto [H, B] = HB;\n    D = cutlass::round_up(D, 8);  // Alignment\n    int Q = cutlass::round_up(static_cast<int>(Q_), 8);  // Alignment\n    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);\n    ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);\n    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);\n    params_.dQ_acc = dQ_acc;\n    params_.dQ_acc_size = sizeof(ElementAccumulator) * B*H*Q*D;\n    auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);\n    auto args_convert = to_convert_arguments(args, dQ_acc);\n    params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);\n    params_.op_convert.initialize(args_convert, nullptr, stream);\n    auto args_bwd = to_bwd_arguments(\n        args, sum_OdO, args_sum_OdO.stride_sum_OdO,\n        scaled_lse, args_sum_OdO.stride_scaled_lse,\n        dQ_acc, args_convert.stride_src_dQ\n    );\n    params_.op.initialize(args_bwd, nullptr, stream);\n\n    return Status::kSuccess;\n  }\n\n  /// Initializes state from arguments.\n  Status\n  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {\n    CUTLASS_TRACE_HOST(\"Universal::initialize() - workspace \"\n      << workspace << \", stream: \" << (stream ? 
\"non-null\" : \"null\"));\n\n    auto [Q_, K, D, D_VO, HB] = args.problem_shape;\n    auto [H, B] = HB;\n    D = cutlass::round_up(D, 8);  // Alignment\n    int Q = cutlass::round_up(static_cast<int>(Q_), 8);  // Alignment\n    char* workspace_chr = reinterpret_cast<char*>(workspace);\n    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);\n    workspace_chr += sizeof(ElementAccumulator) * B*H*Q;\n    ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);\n    workspace_chr += sizeof(ElementAccumulator) * B*H*Q;\n    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);\n    return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);\n  }\n\n  /// Primary run() entry point API that is static allowing users to create and manage their own params.\n  /// Supplied params struct must be construct by calling Kernel::to_underling_arguments()\n  static Status\n  run(Params& params, cudaStream_t stream = nullptr) {\n    CUTLASS_TRACE_HOST(\"FmhaDeviceBwd::run()\");\n\n    Status result = Status::kSuccess;\n    result = params.op_sum_OdO.run(stream);\n    if (result != Status::kSuccess) {\n      return result;\n    }\n\n    auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);\n    if (cuda_result != cudaSuccess) {\n       return Status::kErrorInternal;\n    }\n\n    result = params.op.run(stream);\n    if (result != Status::kSuccess) {\n      return result;\n    }\n\n    result = params.op_convert.run(stream);\n    if (result != Status::kSuccess) {\n      return result;\n    }\n\n    return Status::kSuccess;\n  }\n\n  //\n  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.\n  //\n\n  /// Launches the kernel after first constructing Params internal state from supplied arguments.\n  Status\n  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {\n    Status 
status = initialize(args, workspace, stream);\n    if (Status::kSuccess == status) {\n      status = run(params_, stream);\n    }\n    return status;\n  }\n\n  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.\n  Status\n  run(cudaStream_t stream = nullptr) {\n    return run(params_, stream);\n  }\n\n};\n\n////////////////////////////////////////////////////////////////////////////////\n\n} // namespace cutlass::fmha::device\n\n////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
    "content": "#include \"interface.h\"\n\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n#include <cuda_bf16.h>\n#include \"common/mask.cuh\"\n#include \"common/utils.hpp\"\n\n#include \"fmha_cutlass_bwd_sm100.cuh\"\n\ntemplate<class Mask, class Varlen, class Element, class ElementOut, class Mla>\nvoid call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen,\n                      [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla,\n                  at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,\n                  at::Tensor v, at::Tensor o, at::Tensor lse,\n                  at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,\n                  at::Tensor dq, at::Tensor dk, at::Tensor dv,\n                  float softmax_scale, int max_seqlen_q, int total_seqlen_kv) {\n  static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;\n  static constexpr bool IsMla = std::is_same_v<Mla, true_type>;\n  using TileShape = std::conditional_t<IsMla, Shape<_64, _128, _192, _128>, Shape<_128, _128, _128, _128>>;\n  run_fmha_bwd<Element, IsVarlen, IsMla, TileShape, Mask>(workspace_buffer, d_o, q, k, v, o, lse,\n                          cumulative_seqlen_q, cumulative_seqlen_kv,\n                          dq, dk, dv,\n                          softmax_scale, max_seqlen_q, total_seqlen_kv);\n}\n\n\nvoid FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,\n                            at::Tensor v, at::Tensor o, at::Tensor lse,\n                            at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,\n                            at::Tensor dq, at::Tensor dk, at::Tensor dv,\n                            int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) {\n\n  const c10::cuda::OptionalCUDAGuard device_guard(q.device());\n\n  int head_dim_qk = 
q.size(-1);\n  int head_dim_vo = v.size(-1);\n  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);\n  auto scalar_type_in = q.scalar_type();\n  auto scalar_type_out = o.scalar_type();\n\n  if(scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) {\n    using Element = cutlass::bfloat16_t;\n    using ElementOut = cutlass::bfloat16_t;\n\n    auto apply_config = [&](auto fn) {\n      if (mask_mode == MaskMode::kCausal) {\n        if(is_varlen) {\n          fn(CausalForBackwardMask<false>{}, cute::true_type{}, Element{}, ElementOut{});\n        } else {\n          fn(CausalForBackwardMask<false>{}, cute::false_type{}, Element{}, ElementOut{});\n        }\n      }\n      else {\n        if(is_varlen) {\n          fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{});\n        } else {\n          fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{});\n        }\n      }\n    };\n\n    apply_config([&](auto mask, auto varlen, auto in, auto out) {\n      if (head_dim_qk == 192 && head_dim_vo == 128) {\n        call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse,\n                          cumulative_seqlen_q, cumulative_seqlen_kv,\n                          dq, dk, dv,\n                          softmax_scale, max_seqlen_q, max_seqlen_kv);\n      } else if (head_dim_qk == 128 && head_dim_vo == 128) {\n        call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse,\n                          cumulative_seqlen_q, cumulative_seqlen_kv,\n                          dq, dk, dv,\n                          softmax_scale, max_seqlen_q, max_seqlen_kv);\n      } else {\n        std::cout << \"No kernel instantiated for head_dim_qk=\" << head_dim_qk << \" head_dim_vo=\" << head_dim_vo << std::endl;\n      }\n    });\n\n  } else {\n    FLASH_MLA_ASSERT(false);\n  }\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh",
    "content": "/***************************************************************************************************\n * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n#include <iostream>\n#include <random>\n#include <regex>\n\n#include <cute/tensor.hpp>\n\n#include <cutlass/cutlass.h>\n#include <cutlass/kernel_hardware_info.h>\n\n#include <cutlass/util/command_line.h>\n#include <cutlass/util/distribution.h>\n#include <cutlass/util/reference/device/tensor_fill.h>\n\n#include \"common/utils.hpp\"\n#include \"collective/fmha_fusion.hpp\"\n#include \"device/fmha_device_bwd.hpp\"\n\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n\nusing namespace cute;\nusing namespace cutlass::fmha::kernel;\nusing namespace cutlass::fmha::collective;\nusing namespace cutlass::fmha;\nusing namespace cutlass;\n\n\ntemplate<\n  class DType,\n  bool kIsVarlen,\n  bool kIsMla,\n  class TileShape,\n  class ActiveMask\n>\nstruct BwdRunner {\n\n  using Element = DType;\n  using ElementAccumulator = float;\n\n  // Q K D D_VO (H B)\n  using ProblemShape = std::conditional_t<\n    kIsVarlen,\n    cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,\n    cute::tuple<int, int, int, int, cute::tuple<int, int>>\n  >;\n\n  using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, kIsMla, ActiveMask>;\n  \n  using TensorStride = Stride<int, _1, Stride<int, int>>; \n  using StrideQ = TensorStride;                         
      // Seq DQK (H B)\n  using StrideK = TensorStride;                               // Seq DQK (H B)\n  using StrideV = TensorStride;                               // Seq DVO (H B)\n  using StrideO = TensorStride;                               // Seq DVO (H B)\n  using StrideLSE = Stride<_1, Stride<int, int>>;             // Seq (H B)\n\n  // Backwards specific\n  using StrideDQ = TensorStride;\n  using StrideDK = TensorStride;                              // Seq DQK (H B)\n  using StrideDV = TensorStride;                              // Seq DVO (H B)\n  using StrideDO = TensorStride;\n\n  static void run(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,\n                  at::Tensor v, at::Tensor o, at::Tensor lse,\n                  at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,\n                  at::Tensor dq, at::Tensor dk, at::Tensor dv,\n                  float softmax_scale, int max_seqlen_q, int max_seqlen_kv) {\n    const at::cuda::CUDAGuard device_guard{(char)q.get_device()};\n    const int device_id = q.get_device();\n\n    cutlass::KernelHardwareInfo hw_info;\n    hw_info.device_id =device_id;\n    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);\n    ProblemShape problem_shape;\n    cute::tuple<int, int, int, int, cute::tuple<int, int>> tensor_shape;\n\n\n    int d = q.size(-1);\n    int d_vo = v.size(-1);\n    int batch_size = cumulative_seqlen_q.size(0) - 1;\n    int num_qo_heads = q.size(1);\n    int total_seqlen_q = q.size(0);\n    int total_seqlen_kv = k.size(0);\n    \n    //varlen: q: [Q, H, D]\n    //fixedlen: q: [B, H, Q, D] \n    if constexpr (kIsVarlen) {\n      problem_shape = cute::make_tuple(\n        VariableLength{max_seqlen_q, static_cast<int*>(cumulative_seqlen_q.data_ptr()), total_seqlen_q},\n        VariableLength{max_seqlen_kv, static_cast<int*>(cumulative_seqlen_kv.data_ptr()), total_seqlen_kv},\n        d, d_vo, 
cute::make_tuple(num_qo_heads, batch_size));\n      tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, d, d_vo, make_shape(num_qo_heads, 1));\n    } else {\n      int q_len = total_seqlen_q / batch_size;\n      int kv_len = total_seqlen_kv / batch_size;\n      problem_shape = cute::make_tuple(q_len, kv_len, d, d_vo, cute::make_tuple(num_qo_heads, batch_size));\n      tensor_shape = problem_shape;\n    }\n\n    auto [Q, K, D, D_VO, HB] = tensor_shape;\n    auto [H, B] = HB;\n\n    int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2);\n    int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2);\n    int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2);\n    int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2);\n    int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1);\n    int dq_stride0 = dq.stride(0), dq_stride1 = dq.stride(1), dq_stride2 = dq.stride(2);\n    int dk_stride0 = dk.stride(0), dk_stride1 = dk.stride(1), dk_stride2 = dk.stride(2);\n    int dv_stride0 = dv.stride(0), dv_stride1 = dv.stride(1), dv_stride2 = dv.stride(2);\n    int do_stride0 = d_o.stride(0), do_stride1 = d_o.stride(1), do_stride2 = d_o.stride(2);\n    TORCH_CHECK(q_stride2 == 1);\n    TORCH_CHECK(k_stride2 == 1);\n    TORCH_CHECK(v_stride2 == 1);\n    TORCH_CHECK(o_stride2 == 1);\n    TORCH_CHECK(lse_stride0 == 1);\n    TORCH_CHECK(dq_stride2 == 1);\n    TORCH_CHECK(dk_stride2 == 1);\n    TORCH_CHECK(dv_stride2 == 1);\n    TORCH_CHECK(do_stride2 == 1);\n\n    StrideQ stride_Q = make_stride(q_stride0, _1{}, make_stride(q_stride1, B == 1 ? 0 : q_stride0*Q));\n    StrideK stride_K = make_stride(k_stride0, _1{}, make_stride(k_stride1, B == 1 ? 0 : k_stride0*K));\n    StrideV stride_V = make_stride(v_stride0, _1{}, make_stride(v_stride1, B == 1 ? 0 : v_stride0*K));\n    StrideO stride_O = make_stride(o_stride0, _1{}, make_stride(o_stride1, B == 1 ? 
0 : o_stride0*Q));\n    StrideLSE stride_LSE = make_stride(_1{}, make_stride(lse_stride1, B == 1 ? 0 : Q));\n\n    StrideDQ stride_dQ = make_stride(dq_stride0, _1{}, make_stride(dq_stride1, B == 1 ? 0 : dq_stride0*Q));\n    StrideDK stride_dK = make_stride(dk_stride0, _1{}, make_stride(dk_stride1, B == 1 ? 0 : dk_stride0*K));\n    StrideDV stride_dV = make_stride(dv_stride0, _1{}, make_stride(dv_stride1, B == 1 ? 0 : dv_stride0*K));\n    StrideDO stride_dO = make_stride(do_stride0, _1{}, make_stride(do_stride1, B == 1 ? 0 : do_stride0*Q));\n\n    typename Operation::Arguments arguments{\n      problem_shape,\n      (static_cast<Element*>(q.data_ptr())), stride_Q,\n      (static_cast<Element*>(k.data_ptr())), stride_K,\n      (static_cast<Element*>(v.data_ptr())), stride_V,\n      (static_cast<Element*>(o.data_ptr())), stride_O,\n      (static_cast<ElementAccumulator*>(lse.data_ptr())), stride_LSE,\n      (static_cast<Element*>(d_o.data_ptr())), stride_dO,\n      (static_cast<Element*>(dq.data_ptr())), stride_dQ,\n      (static_cast<Element*>(dk.data_ptr())), stride_dK,\n      (static_cast<Element*>(dv.data_ptr())), stride_dV,\n      static_cast<ElementAccumulator>(softmax_scale),\n      hw_info\n    };\n\n    Operation op;\n\n    uint8_t* workspace_ptr = static_cast<uint8_t*>(workspace_buffer.data_ptr());\n\n    CUTLASS_CHECK(op.can_implement(arguments));\n    CUTLASS_CHECK(op.initialize(arguments, workspace_ptr));\n    CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));\n  }\n\n};\n\n\ntemplate <typename DType, bool kIsVarlen, bool kIsMla, typename TileShape, typename Mask>\nvoid run_fmha_bwd(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,\n                  at::Tensor v, at::Tensor o, at::Tensor lse,\n                  at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,\n                  at::Tensor dq, at::Tensor dk, at::Tensor dv,\n                  float softmax_scale, int max_seqlen_q, int total_seqlen_kv) {\n  
BwdRunner<DType, kIsVarlen, kIsMla, TileShape, Mask>::run(workspace_buffer, d_o, q, k, v, o, lse,\n                                                     cumulative_seqlen_q, cumulative_seqlen_kv,\n                                                     dq, dk, dv,\n                                                     softmax_scale, max_seqlen_q, total_seqlen_kv);\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
    "content": "#include \"interface.h\"\n\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n#include <cuda_bf16.h>\n\n#include \"common/mask.cuh\"\n#include \"common/utils.hpp\"\n\n#include \"fmha_cutlass_fwd_sm100.cuh\"\n\ntemplate <class Mask, class Varlen, class Element, class ElementOut, class Mla>\nvoid call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen,\n                       [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out,\n                       [[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor q,\n                       at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q,\n                       at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse,\n                       float softmax_scale, int max_seqlen_q, int max_seqlen_kv) {\n  static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;\n  static constexpr bool IsMla = std::is_same_v<Mla, true_type>;\n  static constexpr bool IsCausalMask = std::is_same_v<Mask, CausalMask<false>>;\n  using Option =\n      std::conditional_t<IsCausalMask || (IsVarlen), Option<Tag::kIsPersistent, false_type>,\n                         Option<Tag::kIsPersistent, true_type>>;\n\n  run_fmha_fwd<Element, ElementOut, IsVarlen, IsMla, Mask, Option>(\n      workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse,\n      softmax_scale, max_seqlen_q, max_seqlen_kv);\n}\n\nvoid FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k,\n                            at::Tensor v, at::Tensor cumulative_seqlen_q,\n                            at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse,\n                            int mask_mode_code, float sm_scale, int max_seqlen_q,\n                            int max_seqlen_kv, bool is_varlen) {\n  const c10::cuda::OptionalCUDAGuard device_guard(q.device());\n  CHECK(q.scalar_type() == k.scalar_type());\n  auto scalar_type_in = 
q.scalar_type();\n  auto scalar_type_out = o.scalar_type();\n  int head_dim_qk = q.size(-1);\n  int head_dim_vo = v.size(-1);\n  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);\n\n  if (scalar_type_in == at::ScalarType::BFloat16 &&\n      scalar_type_out == at::ScalarType::BFloat16) {\n    using Element = cutlass::bfloat16_t;\n    using ElementOut = cutlass::bfloat16_t;\n\n    auto apply_config = [&](auto fn) {\n      if (mask_mode == MaskMode::kCausal) {\n        if (is_varlen) {\n          fn(CausalMask<false>{}, cute::true_type{}, Element{}, ElementOut{});\n        } else {\n          fn(CausalMask<false>{}, cute::false_type{}, Element{}, ElementOut{});\n        }\n      } else {\n        if (is_varlen) {\n          fn(ResidualMask{}, cute::true_type{}, Element{}, ElementOut{});\n        } else {\n          fn(ResidualMask{}, cute::false_type{}, Element{}, ElementOut{});\n        }\n      }\n    };\n\n    apply_config([&](auto mask, auto varlen, auto in, auto out) {\n      if (head_dim_qk == 192 && head_dim_vo == 128) {\n        call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v,\n                          cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale,\n                          max_seqlen_q, max_seqlen_kv);\n      } else if (head_dim_qk == 128 && head_dim_vo == 128) {\n        call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v,\n                          cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale,\n                          max_seqlen_q, max_seqlen_kv);\n      } else {\n        std::cout << \"No kernel instantiated for head_dim_qk=\" << head_dim_qk\n                  << \" head_dim_vo=\" << head_dim_vo << std::endl;\n      }\n    });\n\n  } else {\n    FLASH_MLA_ASSERT(false);\n  }\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh",
    "content": "#pragma once\n\n#include \"collective/fmha_fusion.hpp\"\n#include \"collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp\"\n#include \"collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp\"\n#include \"collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp\"\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/kernel_hardware_info.h\"\n#include \"device/fmha.hpp\"\n#include \"kernel/fmha_causal_tile_scheduler.hpp\"\n#include \"kernel/fmha_options.hpp\"\n#include \"kernel/fmha_tile_scheduler.hpp\"\n#include \"kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp\"\n\n#include <torch/library.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n\nusing namespace cute;\nusing namespace cutlass::fmha::collective;\nusing namespace cutlass::fmha::kernel;\nusing namespace cutlass::fmha::device;\n\nstruct FmhaOptions {\n  int b = 1;\n  int h = 1;\n  int h_k = 1;\n  int q = 256;\n  int k = 256;\n  int d = 128;\n};\n\nstruct MlaOptions {\n  int b = 1;\n  int h = 1;\n  int h_k = 1;\n  int q = 256;\n  int k = 256;\n  int dl = 128; // headdim latent\n  int dr = 64;  // headdim rope\n};\n\ntemplate <bool kIsMla, bool kIsMaskTileSchedulerValid, bool kIsVarlen, class Element_,\n          class ElementOut_, class ActiveMask, class... 
KernelOptions>\nstruct FwdRunner {\n\n  using Element = Element_;\n  using ElementAccumulatorQK = float;\n  using ElementAccumulatorPV = float;\n  using ElementOut = ElementOut_;\n\n  using HeadDimLatent = _128;\n  using HeadDim = Shape<HeadDimLatent, _64>;\n  using TileShapeMla = Shape<_256, _128, HeadDim>;\n  using TileShapeFmha = Shape<_256, _128, _128>;\n  using TileShape = std::conditional_t<kIsMla, TileShapeMla, TileShapeFmha>;\n\n  using ProblemShapeRegular = std::conditional_t<\n      kIsMla,\n      cute::tuple<int, int, cute::tuple<int, int>, cute::tuple<cute::tuple<int, int>, int>>,\n      cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>>;\n\n  using ProblemShapeVarlen =\n      std::conditional_t<kIsMla,\n                         cute::tuple<VariableLength, VariableLength, cute::tuple<int, int>,\n                                     cute::tuple<cute::tuple<int, int>, int>>,\n                         cute::tuple<VariableLength, VariableLength, int,\n                                     cute::tuple<cute::tuple<int, int>, int>>>;\n\n  using ProblemShapeType =\n      std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;\n\n  using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>;\n  using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>;\n  using StrideV = StrideK;\n  using StrideO = StrideQ;\n  using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>;\n\n  static constexpr bool kIsPersistent =\n      find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;\n\n  using TileScheduler = std::conditional_t<\n      kIsPersistent,\n      std::conditional_t<std::is_same_v<ActiveMask, CausalMask<false>> ||\n                             std::is_same_v<ActiveMask, CausalMask<true>>,\n                         cutlass::fmha::kernel::CausalPersistentTileScheduler,\n                         cutlass::fmha::kernel::PersistentTileScheduler>,\n      
std::conditional_t<kIsMaskTileSchedulerValid,\n                         cutlass::fmha::kernel::CausalIndividualTileScheduler,\n                         cutlass::fmha::kernel::IndividualTileScheduler>>;\n\n  static constexpr bool IsOrderLoadEpilogue =\n      kIsPersistent && (sizeof(Element) == sizeof(ElementOut));\n  using OrderLoadEpilogue = std::conditional_t<IsOrderLoadEpilogue, true_type, false_type>;\n\n  using MainloopMla = cutlass::fmha::collective::Sm100MlaFwdMainloopTmaWarpspecialized<\n      Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeMla, StrideQ, StrideK,\n      StrideV, ActiveMask, Shape<_2, _1, _1>, OrderLoadEpilogue>;\n\n  using OperationMla =\n      cutlass::fmha::device::FMHA<cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<\n          ProblemShapeType, MainloopMla,\n          cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<\n              ElementOut, ElementAccumulatorPV, typename MainloopMla::TileShapePV, StrideO,\n              StrideLSE, OrderLoadEpilogue>,\n          TileScheduler, cutlass::fmha::kernel::Sm100MlaFwdCtxKernelWarpspecializedSchedule>>;\n\n  using MainloopFmha = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<\n      Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeFmha, StrideQ, StrideK,\n      StrideV, ActiveMask>;\n\n  using OperationFmha =\n      cutlass::fmha::device::FMHA<cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<\n          ProblemShapeType, MainloopFmha,\n          cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<\n              ElementOut, ElementAccumulatorPV, typename MainloopFmha::TileShapePV, StrideO,\n              StrideLSE>,\n          TileScheduler>>;\n\n  using Mainloop = std::conditional_t<kIsMla, MainloopMla, MainloopFmha>;\n  using Operation = std::conditional_t<kIsMla, OperationMla, OperationFmha>;\n\n  //\n  // Data members\n  //\n\n  /// Initialization\n  StrideQ stride_Q;\n  StrideK stride_K;\n 
 StrideV stride_V;\n  StrideO stride_O;\n  StrideLSE stride_LSE;\n\n  template <class ProblemShape>\n  auto initialize_varlen(const ProblemShape &problem_size, int max_seqlen_q, int max_seqlen_kv,\n                         int total_seqlen_q, int total_seqlen_kv) {\n\n    int num_batches = get<3, 1>(problem_size);\n\n    ProblemShape problem_size_for_init = problem_size;\n    get<3, 1>(problem_size_for_init) = 1;\n    get<0>(problem_size_for_init) = total_seqlen_q;\n    get<1>(problem_size_for_init) = total_seqlen_kv;\n\n    ProblemShapeType problem_size_for_launch;\n\n    get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};\n    get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};\n    get<2>(problem_size_for_launch) = get<2>(problem_size);\n    get<3>(problem_size_for_launch) = get<3>(problem_size);\n\n    return cute::make_tuple(problem_size_for_init, problem_size_for_launch);\n  }\n\n  template <class Options>\n  static constexpr auto get_problem_shape(const Options &options) {\n    int h_r = options.h / options.h_k;\n    if constexpr (std::is_same_v<Options, MlaOptions>) {\n      return cute::make_tuple(options.q, options.k, cute::make_tuple(options.dl, options.dr),\n                              cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));\n    } else {\n      return cute::make_tuple(options.q, options.k, options.d,\n                              cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));\n    }\n  }\n\n  template <class Options>\n  ProblemShapeType initialize(const Options &options, int max_seqlen_q, int max_seqlen_kv,\n                                   int total_seqlen_q, int total_seqlen_kv,\n                                   void *cumulative_length_q, void *cumulative_length_kv) {\n    assert(options.h % options.h_k == 0);\n    auto problem_shape_in = get_problem_shape(options);\n\n    ProblemShapeType problem_shape;\n    
decltype(problem_shape_in) problem_size;\n\n    if constexpr (kIsVarlen) {\n      auto [problem_shape_init, problem_shape_launch] = initialize_varlen(\n          problem_shape_in, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv);\n      problem_shape = problem_shape_launch;\n      problem_size = problem_shape_init;\n    } else {\n      problem_size = problem_shape_in;\n      problem_shape = problem_shape_in;\n    }\n\n    auto get_head_dimension = [&]() {\n      if constexpr (rank_v<decltype(get<2>(problem_shape))> == 2) {\n        return cute::make_tuple(size<2, 0>(problem_shape) + size<2, 1>(problem_shape),\n                                size<2, 0>(problem_shape));\n      } else {\n        return cute::make_tuple(size<2>(problem_size), size<2>(problem_size));\n      }\n    };\n\n\n    if constexpr (kIsVarlen) {\n      get<0>(problem_shape).cumulative_length = static_cast<int *>(cumulative_length_q);\n      get<1>(problem_shape).cumulative_length = static_cast<int *>(cumulative_length_kv);\n    }\n\n    return problem_shape;\n  }\n\n  auto get_arguments(const ProblemShapeType &problem_shape,\n                     const cutlass::KernelHardwareInfo &hw_info, float scale_softmax,\n                     void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr,\n                     void *cumulative_length_q, void *cumulative_length_kv) {\n    auto problem_shape_ = problem_shape;\n\n    typename Operation::Arguments arguments{\n        problem_shape_,\n        {static_cast<Element *>(q_ptr), stride_Q, static_cast<Element *>(k_ptr), stride_K,\n         static_cast<Element *>(v_ptr), stride_V, scale_softmax},\n        {static_cast<ElementOut *>(o_ptr), stride_O,\n         static_cast<ElementAccumulatorPV *>(lse_ptr), stride_LSE},\n        hw_info};\n\n    return arguments;\n  }\n\n  template <class Options>\n  void run(const Options &options, const cutlass::KernelHardwareInfo &hw_info, at::Tensor q,\n           at::Tensor k, at::Tensor v, 
at::Tensor o, at::Tensor lse, float scale_softmax,\n           at::Tensor workspace, at::Tensor cumulative_seqlen_q,\n           at::Tensor cumulative_seqlen_kv, int max_seqlen_q, int max_seqlen_kv) {\n\n    int total_seqlen_q = q.size(0);\n    int total_seqlen_kv = k.size(0);\n\n    ProblemShapeType problem_shape =\n        initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv,\n                        cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());\n    \n    int SQ = size<0>(problem_shape);\n    int SK = size<1>(problem_shape);\n    int B = size<3, 1>(problem_shape);\n    int H = size<3, 0>(problem_shape);\n    int H_K = size<3, 0, 1>(problem_shape);\n    int H_Q = size<3, 0, 0>(problem_shape);\n\n    int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2);\n    int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2);\n    int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2);\n    int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2);\n    int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1);\n    TORCH_CHECK(q_stride2 == 1);\n    TORCH_CHECK(k_stride2 == 1);\n    TORCH_CHECK(v_stride2 == 1);\n    TORCH_CHECK(o_stride2 == 1);\n    TORCH_CHECK(lse_stride0 == 1);\n\n    stride_Q = make_stride(q_stride0, _1{}, make_stride(make_stride(q_stride1, H_Q * q_stride1), SQ * q_stride0));\n    stride_O = make_stride(o_stride0, _1{}, make_stride(make_stride(o_stride1, H_Q * o_stride1), SQ * o_stride0));\n    stride_K = make_stride(k_stride0, _1{}, make_stride(make_stride(_0{}, k_stride1), SK * k_stride0));\n    stride_V = make_stride(v_stride0, _1{}, make_stride(make_stride(_0{}, v_stride1), SK * v_stride0));\n    stride_LSE = make_stride(_1{}, make_stride(make_stride(lse_stride1, lse_stride1 * H_Q), SQ));\n\n    if constexpr (kIsVarlen) {\n      get<2, 1>(stride_Q) = 0;\n      get<2, 1>(stride_K) = 0;\n      get<2, 1>(stride_V) = 
0;\n      get<2, 1>(stride_O) = 0;\n      get<1, 1>(stride_LSE) = 0;\n    }\n\n    typename Operation::Arguments arguments =\n        get_arguments(problem_shape, hw_info, scale_softmax, q.data_ptr(), k.data_ptr(),\n                      v.data_ptr(), o.data_ptr(), lse.data_ptr(),\n                      cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());\n\n    Operation op;\n\n    // size_t workspace_size = 0;\n    // workspace_size = Operation::get_workspace_size(arguments);\n\n    // TODO: if a workspace is ever used, check the workspace size first.\n    // The current version does not use a workspace.\n\n    CUTLASS_CHECK(op.can_implement(arguments));\n    CUTLASS_CHECK(op.initialize(arguments, nullptr));\n    CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));\n  }\n};\n\ntemplate <class DTypeIn, class DTypeOut, bool kIsVarlen, bool kIsMla, class ActiveMask,\n          class... KernelOptions>\nvoid run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v,\n                  at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o,\n                  at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) {\n\n  const at::cuda::CUDAGuard device_guard{(char)q.get_device()};\n  const int device_id = q.get_device();\n\n  cutlass::KernelHardwareInfo hw_info;\n  hw_info.device_id = device_id;\n  hw_info.sm_count =\n      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);\n\n  auto get_options = [&]() {\n    if constexpr (kIsMla) {\n      MlaOptions options;\n      options.b = cumulative_seqlen_q.size(0) - 1;\n      options.h = q.size(1);\n      options.h_k = k.size(1);\n      options.q = q.size(0) / options.b;\n      options.k = k.size(0) / options.b;\n      options.dl = v.size(-1);\n      options.dr = q.size(-1) - v.size(-1);\n      return options;\n    } else {\n      FmhaOptions options;\n      options.b = cumulative_seqlen_q.size(0) - 1;\n      options.h = 
q.size(1);\n      options.h_k = k.size(1);\n      options.q = q.size(0) / options.b;\n      options.k = k.size(0) / options.b;\n      options.d = q.size(-1);\n      return options;\n    }\n  };\n\n  auto options = get_options();\n\n  if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 &&\n      (std::is_same_v<ActiveMask, CausalMask<false>> || std::is_same_v<ActiveMask, CausalMask<true>>)) {\n    FwdRunner<kIsMla, true, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;\n    runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,\n               cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);\n  } else {\n    FwdRunner<kIsMla, false, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;\n    runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,\n               cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);\n  }\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/interface.h",
    "content": "#pragma once\n\n#include <ATen/Tensor.h>\n\nvoid FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v,\n                            at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,\n                            at::Tensor o, at::Tensor lse,\n                            int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen);\n\nvoid FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,\n                            at::Tensor v, at::Tensor o, at::Tensor lse,\n                            at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,\n                            at::Tensor dq, at::Tensor dk, at::Tensor dv,\n                            int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen);\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/fast_math.h\"\n\nnamespace cutlass::fmha::kernel {\n\n////////////////////////////////////////////////////////////////////////////////\n\n// Swizzle Q tile and H tile to improve L2 cache hit rate, \n// and launch the longest main loop first to keep most SMs busy.\n\nstruct CausalIndividualTileScheduler {\n  \n  static constexpr int TileQ = 16;\n  static constexpr int TileH = 8;\n  static constexpr int TileSize = TileQ * TileH;\n\n  struct Params {\n    dim3 grid;\n    int tile_max_q;\n    FastDivmod divmod_tile_col;\n    FastDivmod divmod_tile_size;\n    FastDivmod divmod_tile_head;\n  };\n\n  bool valid_ = true;\n  Params params;\n\n  CUTLASS_DEVICE\n  CausalIndividualTileScheduler(Params const& params) : params(params) {}\n\n  template<class ProblemSize, class ClusterShape, class TileShape>\n  static Params to_underlying_arguments(\n      ProblemSize const& problem_size, KernelHardwareInfo hw_info,\n      ClusterShape const& cluster_shape, TileShape const& tile_shape) {\n    using namespace cute;\n\n    dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));\n    // gridDim.x must be a multiple of TileH\n    const int tile_col_count = grid.x / TileH;\n    const int 
tile_max_q = grid.y / TileQ * TileQ;\n    return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH};\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    return params.grid;\n  }\n\n  CUTLASS_DEVICE\n  bool is_valid() {\n    return valid_;\n  }\n\n  CUTLASS_DEVICE\n  auto get_block_coord() {\n    using namespace cute;\n    const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;\n\n    int tile_idx, tile_tail;\n    params.divmod_tile_size(tile_idx, tile_tail, block_idx);\n\n    int tile_row_idx, tile_col_idx;\n    params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx);\n\n    int row_offset_in_tail, col_offset_in_tail;\n    params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail);\n\n    const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;\n    const int col_idx = tile_col_idx * TileH + col_offset_in_tail;\n    \n    // last q tile launch first\n    if(blockIdx.y >= params.tile_max_q) {\n      return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));\n    } \n\n    return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));\n  }\n\n  CUTLASS_DEVICE\n  CausalIndividualTileScheduler& operator++() {\n    valid_ = false;\n    return *this;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////\n\n\n////////////////////////////////////////////////////////////////////////////////\n\n// Launch order: H Q B\nstruct CausalPersistentTileScheduler {\n\n  struct Params {\n    int num_blocks;\n    FastDivmod divmod_h;\n    FastDivmod divmod_m_block;\n    FastDivmod divmod_b;\n\n    KernelHardwareInfo hw_info;\n  };\n\n  int block_idx = 0;\n  Params params;\n\n  CUTLASS_DEVICE\n  CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}\n\n  template<class ProblemSize, class ClusterShape, class TileShape>\n  static Params to_underlying_arguments(\n      ProblemSize const& 
problem_size, KernelHardwareInfo hw_info,\n      ClusterShape const& cluster_shape, TileShape const& tile_shape) {\n    using namespace cute;\n    // Get SM count if needed, otherwise use user supplied SM count\n    int sm_count = hw_info.sm_count;\n    if (sm_count <= 0) {\n      CUTLASS_TRACE_HOST(\"  WARNING: Arguments do not include a valid SM count.\\n\"\n          \"  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.\");\n      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);\n    }\n\n    CUTLASS_TRACE_HOST(\"to_underlying_arguments(): Setting persistent grid SM count to \" << sm_count);\n    hw_info.sm_count = sm_count;\n\n    int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));\n    int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);\n\n    return Params {\n      num_blocks,\n      { size<3,0>(problem_size) }, { max(1, num_m_blocks) }, { size<3,1>(problem_size) },\n      hw_info\n    };\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);\n    return grid;\n  }\n\n  CUTLASS_DEVICE\n  bool is_valid() {\n    return block_idx < params.num_blocks;\n  }\n\n  CUTLASS_DEVICE\n  auto get_block_coord() {\n    using namespace cute;\n    int block_decode = block_idx;\n    int m_block, bidb, bidh;\n    params.divmod_h(block_decode, bidh, block_decode);\n    params.divmod_m_block(block_decode, m_block, block_decode);\n    params.divmod_b(block_decode, bidb, block_decode);\n    return make_coord(m_block, _0{}, make_coord(bidh, bidb));\n  }\n\n  CUTLASS_DEVICE\n  CausalPersistentTileScheduler& operator++() {\n    block_idx += gridDim.x;\n    return *this;\n  }\n};\n////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cute/layout.hpp\"\n#include <kerutils/kerutils.cuh> // for  KERUTILS_ENABLE_SM100A\n\nnamespace cutlass::fmha::kernel {\n\nusing namespace cute;\n\ntemplate<class ProblemShape, class Element, class ElementAcc>\nstruct FmhaKernelBwdConvert {\n\n  struct Arguments {\n    ProblemShape problem_shape;\n\n    const ElementAcc* ptr_src_dQ;\n    tuple<int, _1, tuple<int, int>> stride_src_dQ;\n    const ElementAcc* ptr_src_dK;\n    tuple<int, _1, tuple<int, int>> stride_src_dK;\n    const ElementAcc* ptr_src_dV;\n    tuple<int, _1, tuple<int, int>> stride_src_dV;\n\n    Element* ptr_dest_dQ;\n    tuple<int, _1, tuple<int, int>> stride_dest_dQ;\n    Element* ptr_dest_dK;\n    tuple<int, _1, tuple<int, int>> stride_dest_dK;\n    Element* ptr_dest_dV;\n    tuple<int, _1, tuple<int, int>> stride_dest_dV;\n\n    ElementAcc scale = 1.0;\n  };\n\n  using Params = Arguments;\n\n  using ClusterShape = Shape<_1, _1, _1>;\n  static constexpr int SharedStorageSize = 0;\n\n  static const int MinBlocksPerMultiprocessor = 1;\n  static const int MaxThreadsPerBlock = 128;\n  using ArchTag = cutlass::arch::Sm90;\n\n  static const int kBlockSeq = 8;\n\n  static size_t get_workspace_size(Arguments const& args) { return 0; }\n  static cutlass::Status initialize_workspace(Arguments const&, void*, 
cudaStream_t) {\n    return cutlass::Status::kSuccess;\n  }\n\n  static const int kNumThreadsD = 16;\n  static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;\n  static const int kElementsPerLoad = 4;\n\n  static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;\n\n  static bool can_implement(Arguments const& args) {\n    return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));\n    return grid;\n  }\n\n  static dim3 get_block_shape() {\n    dim3 block(kNumThreadsD, kNumThreadsSeq, 1);\n    return block;\n  }\n\n  static Params to_underlying_arguments(Arguments const& args, void* workspace) {\n    return args;\n  }\n\n  template<class StrideSrc, class StrideDest, class Count>\n  CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {\n    auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;\n    auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;\n\n    int seqlen = count;\n    if constexpr (is_variable_length_v<decltype(count)>) {\n      int offset = count.cumulative_length[blockIdx.y];\n      ptr_dest_bh += offset * get<0>(stride_dest);\n      seqlen = count.cumulative_length[blockIdx.y + 1] - offset;\n    }\n\n    for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {\n      int idx_s = idx_s_t + kBlockSeq * blockIdx.z;\n      if (idx_s >= seqlen) continue;\n      auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);\n      auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);\n\n      for 
(int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) {\n        ElementAcc value_src[kElementsPerLoad];\n        Element value_dest[kElementsPerLoad];\n\n        using VecSrc = uint_bit_t<sizeof_bits_v<ElementAcc> * kElementsPerLoad>;\n        using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;\n        *reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);\n\n        for (int v = 0; v < kElementsPerLoad; v++) {\n          value_dest[v] = static_cast<Element>(params.scale * value_src[v]);\n        }\n\n        *reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);\n      }\n    }\n  }\n\n  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {\n#if defined(KERUTILS_ENABLE_SM100A)\n    if (params.ptr_src_dQ != nullptr) {\n      copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));\n    }\n    if (params.ptr_src_dK != nullptr) {\n      copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape));\n    }\n    if (params.ptr_src_dV != nullptr) {\n      copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape));\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\\n\");\n    }\n#endif\n  }\n};\n\n}  // namespace cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cute/layout.hpp\"\n#include <kerutils/kerutils.cuh> // for  KERUTILS_ENABLE_SM100A\n\nnamespace cutlass::fmha::kernel {\n\nusing namespace cute;\n\ntemplate<class ProblemShape, class Element, class ElementAcc>\nstruct FmhaKernelBwdSumOdO {\n\n  struct Arguments {\n    ProblemShape problem_shape;\n\n    const Element* ptr_O;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;\n    const Element* ptr_dO;\n    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;\n\n    ElementAcc* ptr_sum_OdO;\n    cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;\n\n    const ElementAcc* ptr_lse = nullptr;\n    cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;\n\n    ElementAcc* ptr_scaled_lse = nullptr;\n    cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;\n\n    ElementAcc sum_odo_scale = 1.0;\n    ElementAcc lse_scale = 1.0;\n  };\n\n  using Params = Arguments;\n\n  using ClusterShape = Shape<_1, _1, _1>;\n  static constexpr int SharedStorageSize = 0;\n\n  static const int MinBlocksPerMultiprocessor = 1;\n  static const int MaxThreadsPerBlock = 128;\n  using ArchTag = cutlass::arch::Sm100;\n\n  static size_t get_workspace_size(Arguments const& args) { return 0; }\n  static cutlass::Status initialize_workspace(Arguments const&, 
void*, cudaStream_t) {\n    return cutlass::Status::kSuccess;\n  }\n\n  static const int kBlockQ = 16;\n\n  static const int kNumThreadsD = 8;\n  static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;\n  static const int kElementsPerLoad = 2;\n\n  static const int kIterationsQ = kBlockQ / kNumThreadsQ;\n\n  static bool can_implement(Arguments const& args) {\n    return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape));\n    return grid;\n  }\n\n  static dim3 get_block_shape() {\n    dim3 block(kNumThreadsD, kNumThreadsQ, 1);\n    return block;\n  }\n\n  static Params to_underlying_arguments(Arguments const& args, void* workspace) {\n    return args;\n  }\n\n  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {\n#if defined(KERUTILS_ENABLE_SM100A)\n    auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);\n    auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);\n    auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);\n    auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);\n    auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);\n\n    auto problem_q = get<0>(params.problem_shape);\n    int seqlen_q = problem_q;\n    if constexpr (is_variable_length_v<decltype(problem_q)>) {\n      int offset = problem_q.cumulative_length[blockIdx.z];\n      ptr_O_bh += offset * get<0>(params.stride_O);\n      ptr_dO_bh += offset * 
get<0>(params.stride_dO);\n      ptr_lse_bh += offset * get<0>(params.stride_lse);\n      seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset;\n    }\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {\n      int idx_q = idx_q_t + kBlockQ * blockIdx.x;\n      if (idx_q >= seqlen_q) continue;\n      ElementAcc acc = 0;\n      auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);\n      auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);\n      auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO);\n      auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);\n      auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);\n\n      for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {\n        Element value_O[kElementsPerLoad];\n        Element value_dO[kElementsPerLoad];\n\n        using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;\n        *reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);\n        *reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);\n\n        for (int v = 0; v < kElementsPerLoad; v++) {\n          acc += ElementAcc(value_O[v]) * ElementAcc(value_dO[v]);\n        }\n      }\n\n      for (int i = 1; i < kNumThreadsD; i *= 2) {\n        acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);\n      }\n\n      if (threadIdx.x == 0) {\n        *ptr_sum_OdO_bhq = params.sum_odo_scale * acc;\n        if (params.ptr_scaled_lse) {\n          *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq;\n        }\n      }\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\\n\");\n    }\n#endif\n  }\n};\n\n}  // namespace cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/fmha_options.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n\n#include \"cutlass/cutlass.h\"\n\nnamespace cutlass::fmha::kernel {\n\ntemplate<auto kTag, typename Default, typename... Options>\nstruct find_option;\n\ntemplate<auto kTag, typename Default>\nstruct find_option<kTag, Default> {\n  using option_value = Default;\n};\n\ntemplate<auto kTag, typename Default, typename Option, typename... Options>\nstruct find_option<kTag, Default, Option, Options...> :\n  std::conditional_t<\n    Option::tag == kTag,\n    Option,\n    find_option<kTag, Default, Options...>\n  >\n{};\n\ntemplate<auto kTag, typename Default, typename... Options>\nusing find_option_t = typename find_option<kTag, Default, Options...>::option_value;\n\nenum class Tag {\n  kIsPersistent,\n  kNumMmaWarpGroups,\n  kLoadsQSeparately,\n\n  kIsMainloopLocked,\n  kIsEpilogueLocked,\n\n  kStagesQ,\n  kStagesKV,\n\n  kEpilogueKind,\n\n  kBlocksPerSM,\n  kClusterM,\n\n  kAccQK\n};\n\ntemplate<auto kTag, class Value>\nstruct Option {\n  static constexpr auto tag = kTag;\n  using option_value = Value;\n};\n\n}  // namespace cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/fast_math.h\"\n#include \"cutlass/kernel_hardware_info.h\"\n\nnamespace cutlass::fmha::kernel {\n\n////////////////////////////////////////////////////////////////////////////////\n\nstruct IndividualTileScheduler {\n\n  struct Params {\n    dim3 grid;\n  };\n\n  bool valid_ = true;\n\n  CUTLASS_DEVICE\n  IndividualTileScheduler(Params const&) {}\n\n  template<class ProblemSize, class ClusterShape, class TileShape>\n  static Params to_underlying_arguments(\n      ProblemSize const& problem_size, KernelHardwareInfo hw_info,\n      ClusterShape const& cluster_shape, TileShape const& tile_shape) {\n    using namespace cute;\n    dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size));\n    return Params{ grid };\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    return params.grid;\n  }\n\n  CUTLASS_DEVICE\n  bool is_valid() {\n    return valid_;\n  }\n\n  CUTLASS_DEVICE\n  auto get_block_coord() {\n    using namespace cute;\n    return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));\n  }\n\n  CUTLASS_DEVICE\n  IndividualTileScheduler& operator++() {\n    valid_ = false;\n    return *this;\n  
}\n};\n\n////////////////////////////////////////////////////////////////////////////////\n\nstruct PersistentTileScheduler {\n\n  struct Params {\n    int num_blocks;\n    FastDivmod divmod_m_block;\n    FastDivmod divmod_h;\n    FastDivmod divmod_b;\n\n    KernelHardwareInfo hw_info;\n  };\n\n  int block_idx = 0;\n  Params params;\n\n  CUTLASS_DEVICE\n  PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}\n\n  template<class ProblemSize, class ClusterShape, class TileShape>\n  static Params to_underlying_arguments(\n      ProblemSize const& problem_size, KernelHardwareInfo hw_info,\n      ClusterShape const& cluster_shape, TileShape const& tile_shape) {\n    using namespace cute;\n    // Get SM count if needed, otherwise use user supplied SM count\n    int sm_count = hw_info.sm_count;\n    if (sm_count <= 0) {\n      CUTLASS_TRACE_HOST(\"  WARNING: Arguments do not include a valid SM count.\\n\"\n          \"  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.\");\n      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);\n    }\n\n    CUTLASS_TRACE_HOST(\"to_underlying_arguments(): Setting persistent grid SM count to \" << sm_count);\n    hw_info.sm_count = sm_count;\n\n    int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));\n    int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);\n\n    return Params {\n      num_blocks,\n      { max(1, num_m_blocks)}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },\n      hw_info\n    };\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);\n    return grid;\n  }\n\n  CUTLASS_DEVICE\n  bool is_valid() {\n    return block_idx < params.num_blocks;\n  }\n\n  CUTLASS_DEVICE\n  auto get_block_coord() {\n    using namespace cute;\n    int 
block_decode = block_idx;\n    int m_block, bidb, bidh;\n    params.divmod_m_block(block_decode, m_block, block_decode);\n    params.divmod_b(block_decode, bidb, block_decode);\n    params.divmod_h(block_decode, bidh, block_decode);\n    return make_coord(m_block, _0{}, make_coord(bidh, bidb));\n  }\n\n  CUTLASS_DEVICE\n  PersistentTileScheduler& operator++() {\n    block_idx += gridDim.x;\n    return *this;\n  }\n};\n\n\n////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2025  - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n\n#include \"cute/tensor.hpp\"\n#include \"cute/arch/simd_sm100.hpp\"\n\n#include \"cutlass/arch/arch.h\"\n#include \"cutlass/arch/memory_sm80.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n\n#include <kerutils/kerutils.cuh> // for  KERUTILS_ENABLE_SM100A\n#include \"../collective/fmha_common.hpp\"\n\n#include <cmath>\n\nnamespace cutlass::fmha::kernel {\n\nusing namespace cutlass::fmha::collective;\n\nusing namespace cute;\n\ntemplate<\n    class ProblemShape,\n    class Element,\n    class ElementAcc,\n    class TileShape,\n    class Mask\n>\nstruct Sm100FmhaBwdKernelTmaWarpSpecialized {\n\n  using TileShapeQ = decltype(get<0>(TileShape{}));\n  static_assert(std::is_same_v<TileShapeQ, _128>, \"tile shape Q must be 128\");\n  using TileShapeK = decltype(get<1>(TileShape{}));\n  static_assert(std::is_same_v<TileShapeK, _128>, \"tile shape K must be 128\");\n  using TileShapeDQK = decltype(get<2>(TileShape{}));\n  using TileShapeDVO = decltype(get<2>(TileShape{}));\n\n  using TmemAllocator = cute::TMEM::Allocator1Sm;\n  struct TmemAllocation {\n    static constexpr uint32_t kDK = 0;                     // TileShapeK x TileShapeDQK x acc\n    static constexpr uint32_t kDV = kDK + TileShapeDQK{};  // TileShapeK x TileShapeDVO x acc\n    static constexpr uint32_t 
kDQ = kDV + TileShapeDVO{};  // TileShapeQ x TileShapeDQK x acc\n    static constexpr uint32_t kDP = kDQ;                   // TileShapeK x TileShapeQ   x inp\n    static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{});\n    static constexpr uint32_t kP = kS;\n    static constexpr uint32_t kTotal = kS + TileShapeQ{};\n  };\n\n  static_assert(\n      static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns,\n      \"using too much tmem\"\n  );\n\n  enum class WarpRole {\n    Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4\n  };\n\n  static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull;\n  static constexpr int kNumComputeWarps = 8;\n  static constexpr int kNumReduceWarps = 4;\n  CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {\n    return static_cast<WarpRole>((kWarpAssignment >> (4 * warp_idx)) & 0xF);\n  }\n\n  struct RegisterAllocation {\n    static constexpr int kWarpgroup0 = 160-8;\n    static constexpr int kWarpgroup1 = 128;\n    static constexpr int kWarpgroup2 = 96;\n    static constexpr int kReduce = kWarpgroup0;\n    static constexpr int kCompute = kWarpgroup1;\n    static constexpr int kMma = kWarpgroup2;\n    static constexpr int kEmpty = kWarpgroup2;\n    static constexpr int kLoad = kWarpgroup2;\n\n    static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512);\n  };\n\n  using ArchTag = cutlass::arch::Sm100;\n\n  using ClusterShape = Shape<_1, _1, _1>;\n  using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;\n\n  static constexpr int MinBlocksPerMultiprocessor = 1;\n  static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4;\n  static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps;\n\n  static constexpr int Alignment = 128 / sizeof_bits_v<Element>;\n  static constexpr int kStages = 2;\n\n  using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;\n  using TensorStrideContiguousMN = Stride<_1, int, 
Stride<int, int>>;\n\n  // compute S\n  using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      Element, TensorStrideContiguousK, Alignment,\n      Element, TensorStrideContiguousK, Alignment,\n      ElementAcc,\n      Shape<TileShapeK, TileShapeQ, TileShapeDQK>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeKQ = typename CollectiveMmaKQ::TileShape;\n  using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma;\n\n  // compute dP\n  using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      Element, TensorStrideContiguousK, Alignment,\n      Element, TensorStrideContiguousK, Alignment,\n      ElementAcc,\n      Shape<TileShapeK, TileShapeQ, TileShapeDVO>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeVDO = typename CollectiveMmaVDO::TileShape;\n  using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma;\n\n  // compute dV\n  using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // needs to match ordering of S calculation\n      Element, TensorStrideContiguousK, Alignment,\n      Element, TensorStrideContiguousMN, Alignment,\n      ElementAcc,\n      Shape<TileShapeK, TileShapeDVO, TileShapeQ>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapePDO = typename CollectiveMmaPDO::TileShape;\n  using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{}));\n\n  // compute dK\n  using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // somewhat arbitrary since we dump to smem, 
need to agree with the next one\n      Element, TensorStrideContiguousK , Alignment,\n      Element, TensorStrideContiguousMN, Alignment,\n      ElementAcc,\n      Shape<TileShapeK, TileShapeDQK, TileShapeQ>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape;\n  using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma;\n\n  // compute dQ\n  using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // somewhat arbitrary since we dump to smem, need to agree with the previous one\n      Element, TensorStrideContiguousMN, Alignment,\n      Element, TensorStrideContiguousMN, Alignment,\n      ElementAcc,\n      Shape<TileShapeQ, TileShapeDQK, TileShapeK>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeDSK = typename CollectiveMmaDSK::TileShape;\n  using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma;\n\n  // pipelines are named Pipeline<Producer><Consumer><Resource>\n  static constexpr int kStagesComputeSmem = 1;\n  using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>;\n  using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>;\n  using PipelineLoadComputeLSE = PipelineAsync<1>;\n  using PipelineLoadComputeSumOdO = PipelineAsync<1>;\n  using PipelineMmaComputeS = PipelineUmmaAsync<1>;\n  using PipelineMmaComputeDP = PipelineUmmaAsync<1>;\n  using PipelineMmaReduceDQ = PipelineUmmaAsync<1>;\n  using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>;\n  using PipelineComputeMmaDS = PipelineUmmaConsumerAsync<kStagesComputeSmem>;\n  using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>;\n  static constexpr int kStagesReduceTmaStore = 2;\n  using PipelineReduceTmaStore = PipelineTmaStore<kStagesReduceTmaStore>;\n\n  struct PipelineStorage {\n    alignas(16) typename 
PipelineLoadMmaQ::SharedStorage load_mma_q;\n    alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do;\n    alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse;\n    alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo;\n    alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s;\n    alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp;\n    alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq;\n    alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p;\n    alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds;\n    alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv;\n  };\n\n  template<class Layout, class Stages = _1>\n  static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) {\n    return composition(layout, make_tuple(_, _, _, make_layout(stages)));\n  }\n\n  using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{}));\n  using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{}));\n  using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{}));\n  using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{}));\n  using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int<kStagesComputeSmem>{}));\n  using SmemLayoutLSE = Layout<Shape<TileShapeQ, _1>>;\n  using SmemLayoutSumOdO = Layout<Shape<TileShapeQ, _1>>;\n\n  using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{}));\n  using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{}));\n  using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int<kStagesComputeSmem>{}));\n  using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{}));\n\n  using TileShapeDQ = _32;\n  using SmemAtomDQ = 
decltype(cutlass::gemm::collective::detail::sm100_smem_selector<\n      cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ\n  >());\n  using SmemShapeDQ = Shape<TileShapeQ, TileShapeDQ, Int<kStagesReduceTmaStore>>;\n  using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{}));\n\n  struct TensorStorage {\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutK>> smem_k;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKT>> smem_k_t;\n    };\n    alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQT>> smem_q_t;\n    };\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDO>> smem_do;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDOT>> smem_do_t;\n    };\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDST>> smem_ds_t;\n    };\n    alignas(1024) cute::array<ElementAcc, cute::cosize_v<SmemLayoutDQ>> smem_dq;\n    alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutLSE>> smem_lse;\n    alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutSumOdO>> smem_sum_odo;\n  };\n\n  static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);\n  static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v<Element>);\n\n  static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);\n  static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);\n\n  struct 
SharedStorage {\n    TensorStorage tensors;\n    PipelineStorage pipelines;\n    uint32_t tmem_base_ptr;\n  };\n\n  // this is tight enough that it won't work with sizeof due to padding for alignment\n  static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);\n  static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, \"using too much smem\");\n\n  using TensorStride = TensorStrideContiguousK;  // S D (H B)\n  using RowTensorStride = Stride<_1, Stride<int, int>>;    // S (H B)\n\n  struct MainloopArguments {\n    const Element* ptr_q;\n    TensorStride stride_q;\n    const Element* ptr_k;\n    TensorStride stride_k;\n    const Element* ptr_v;\n    TensorStride stride_v;\n    const Element* ptr_do;\n    TensorStride stride_do;\n\n    const ElementAcc* ptr_lse;\n    RowTensorStride stride_lse;\n\n    const ElementAcc* ptr_sum_odo;\n    RowTensorStride stride_sum_odo;\n\n    ElementAcc* ptr_dq_acc;\n    TensorStride stride_dq_acc;\n\n    ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{});\n  };\n\n  using TMA_K = typename CollectiveMmaKQ::Params::TMA_A;\n  using TMA_V = typename CollectiveMmaVDO::Params::TMA_A;\n  using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B;\n  using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B;\n\n  using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},\n      make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),\n      SmemLayoutDQ{}(_, _, _0{})\n  ));\n\n  struct MainloopParams {\n    TMA_K tma_load_k;\n    TMA_V tma_load_v;\n    TMA_Q tma_load_q;\n    TMA_DO tma_load_do;\n    TMA_DQ tma_red_dq;\n  };\n\n  struct EpilogueArguments {\n    Element* ptr_dk;\n    TensorStride stride_dk;\n    Element* ptr_dv;\n    TensorStride stride_dv;\n  };\n\n  struct Arguments {\n    ProblemShape problem_shape;\n    MainloopArguments mainloop;\n    EpilogueArguments epilogue;\n    KernelHardwareInfo hw_info;\n  };\n\n  struct Params {\n    
ProblemShape problem_shape;\n    MainloopArguments mainloop;\n    MainloopParams mainloop_params;\n    EpilogueArguments epilogue;\n    KernelHardwareInfo hw_info;\n  };\n\n\n  static bool can_implement(Arguments const& args) {\n    auto [Q, K, D, D_VO, HB] = args.problem_shape;\n    auto [H, B] = HB;\n    if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {\n      return false;\n    }\n    if (D % Alignment != 0 || D_VO % Alignment != 0) {\n      return false;\n    }\n    return true;\n  }\n\n\n  static Status initialize_workspace(Arguments const&, void*, cudaStream_t) {\n    return Status::kSuccess;\n  }\n\n\n  static Params to_underlying_arguments(Arguments const& args, void*) {\n    auto [Q_, K_, D, D_VO, HB] = args.problem_shape;\n    int Q = Q_;\n    int K = K_;\n\n    if constexpr (is_variable_length_v<decltype(Q_)>) {\n      Q = Q_.total_length;\n    }\n    if constexpr (is_variable_length_v<decltype(K_)>) {\n      K = K_.total_length;\n    }\n\n    auto params_kq = CollectiveMmaKQ::to_underlying_arguments(\n      make_shape(K, Q, D, HB),\n      typename CollectiveMmaKQ::Arguments {\n        args.mainloop.ptr_k, args.mainloop.stride_k,\n        args.mainloop.ptr_q, args.mainloop.stride_q,\n      }, /*workspace=*/nullptr);\n\n    auto params_vdo = CollectiveMmaVDO::to_underlying_arguments(\n      make_shape(K, Q, D_VO, HB),\n      typename CollectiveMmaVDO::Arguments {\n        args.mainloop.ptr_v, args.mainloop.stride_v,\n        args.mainloop.ptr_do, args.mainloop.stride_do,\n      }, /*workspace=*/nullptr);\n\n    TMA_DQ tma_red_dq = make_tma_copy(\n        SM90_TMA_REDUCE_ADD{},\n        make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),\n        SmemLayoutDQ{}(_, _, _0{})\n    );\n\n    return Params{\n      args.problem_shape,\n      args.mainloop,\n      MainloopParams{\n        params_kq.tma_load_a,\n        params_vdo.tma_load_a,\n        params_kq.tma_load_b,\n        params_vdo.tma_load_b,\n    
    tma_red_dq\n      },\n      args.epilogue,\n      args.hw_info\n    };\n  }\n\n\n  template<class T>\n  static CUTLASS_DEVICE auto quantize(T const& input) {\n    constexpr int AlignmentS = 4;\n    auto output = make_tensor<Element>(shape(input));\n    auto input_vec = recast<Array<ElementAcc, AlignmentS>>(input);\n    auto output_vec = recast<Array<Element, AlignmentS>>(output);\n\n    cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(input_vec); i++) {\n      output_vec(i) = epilogue_op(input_vec(i));\n    }\n\n    return output;\n  }\n\n\n  template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void load(\n      BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      MainloopParams const& mainloop_params,\n      TensorStorage& shared_tensors,\n      PipelineLoadMmaQ& pipeline_load_mma_q,\n      typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state,\n      PipelineLoadMmaDO& pipeline_load_mma_do,\n      typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state,\n      PipelineLoadComputeLSE& pipeline_load_compute_lse,\n      typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state,\n      PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,\n      typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n\n    using X = Underscore;\n\n    uint16_t mcast_mask = 0;\n\n    auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));\n    auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));\n    auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));\n    auto mDO_in = 
mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));\n\n    auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);\n    auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);\n    auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);\n    auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);\n\n    auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});\n    auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});\n    auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{});\n    auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step<X, _1, _1>{});\n\n    ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});\n    ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});\n\n    auto tSTgK = cta_mma_kq.partition_A(gK);\n    auto tSTgQ = cta_mma_kq.partition_B(gQ);\n    auto tDPTgV = cta_mma_vdo.partition_A(gV);\n    auto tDPTgDO = cta_mma_vdo.partition_B(gDO);\n\n    auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});\n    auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});\n    auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});\n    auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});\n\n    auto [tKgK_mkl, tKsK] = tma_partition(\n        mainloop_params.tma_load_k, _0{}, make_layout(_1{}),\n        group_modes<0,3>(sK), group_modes<0,3>(tSTgK));\n    auto [tQgQ_mkl, tQsQ] = tma_partition(\n        mainloop_params.tma_load_q, _0{}, make_layout(_1{}),\n        group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ));\n    auto [tVgV_mkl, tVsV] = tma_partition(\n        mainloop_params.tma_load_v, _0{}, make_layout(_1{}),\n        group_modes<0,3>(sV), group_modes<0,3>(tDPTgV));\n    auto [tDOgDO_mkl, tDOsDO] = tma_partition(\n        mainloop_params.tma_load_do, _0{}, make_layout(_1{}),\n        group_modes<0,3>(sDO), 
group_modes<0,3>(tDPTgDO));\n\n    // set up lse and sum_odo\n\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);\n    auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);\n\n    pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK);\n\n    // load K\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask),\n          tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch),\n          tKsK(_, _0{})\n      );\n    }\n\n    // load Q\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),\n          tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),\n          tQsQ(_, pipeline_load_mma_q_producer_state.index())\n      );\n    }\n\n    ++pipeline_load_mma_q_producer_state;\n\n    pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);\n\n    // load LSE\n    // 32 threads loading 128 values of 32b each\n    // so 4*32b=128b\n\n    int thread_idx = threadIdx.x % NumThreadsPerWarp;\n    int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;\n    int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;\n    auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);\n    for (int i = 0; i < 4; i++) {\n      cutlass::arch::cp_async_zfill<4>(\n          shared_tensors.smem_lse.begin() + smem_idx + i,\n          &mLSE(gmem_idx + i, blk_coord_batch),\n          gmem_idx + i < Q\n      );\n    }\n\n    pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);\n    ++pipeline_load_compute_lse_producer_state;\n\n\n    
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);\n    tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);\n\n    pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);\n\n    // load V\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask),\n          tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch),\n          tVsV(_, _0{})\n      );\n    }\n\n    // load dO\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),\n          tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),\n          tDOsDO(_, pipeline_load_mma_do_producer_state.index())\n      );\n    }\n\n    ++pipeline_load_mma_do_producer_state;\n\n    pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);\n\n    // load sum_OdO\n    smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;\n    gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;\n    auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);\n    for (int i = 0; i < 4; i++) {\n      cutlass::arch::cp_async_zfill<4>(\n          shared_tensors.smem_sum_odo.begin() + smem_idx + i,\n          &mSumOdO(gmem_idx + i, blk_coord_batch),\n          gmem_idx + i < Q\n      );\n    }\n\n    pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);\n    ++pipeline_load_compute_sum_odo_producer_state;\n\n    iter_count -= 1;\n    iter_index += 1;\n\n    while (iter_count > 0) {\n      pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);\n      tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);\n\n      // load Q\n      if 
(cute::elect_one_sync()) {\n        cute::copy(\n            mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),\n            tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),\n            tQsQ(_, pipeline_load_mma_q_producer_state.index())\n        );\n      }\n\n      ++pipeline_load_mma_q_producer_state;\n\n      pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);\n\n      // load LSE\n      smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;\n      gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;\n      for (int i = 0; i < 4; i++) {\n        cutlass::arch::cp_async_zfill<4>(\n            shared_tensors.smem_lse.begin() + smem_idx + i,\n            &mLSE(gmem_idx + i, blk_coord_batch),\n            gmem_idx + i < Q\n        );\n      }\n\n      pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);\n      ++pipeline_load_compute_lse_producer_state;\n\n      pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);\n      tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);\n\n      // load dO\n      if (cute::elect_one_sync()) {\n        cute::copy(\n            mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),\n            tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),\n            tDOsDO(_, pipeline_load_mma_do_producer_state.index())\n        );\n      }\n\n      ++pipeline_load_mma_do_producer_state;\n\n      pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);\n\n      // load sum_OdO\n      smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;\n      gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;\n      for (int i = 0; i < 4; i++) {\n        cutlass::arch::cp_async_zfill<4>(\n            shared_tensors.smem_sum_odo.begin() + smem_idx + i,\n    
        &mSumOdO(gmem_idx + i, blk_coord_batch),\n            gmem_idx + i < Q\n        );\n      }\n\n      pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);\n      ++pipeline_load_compute_sum_odo_producer_state;\n\n      iter_count -= 1;\n      iter_index += 1;\n    }\n  }\n\n\n  template<class BlkCoord, class ProblemShape_>\n  CUTLASS_DEVICE void mma(\n      BlkCoord const& blk_coord,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      TensorStorage& shared_tensors,\n      PipelineLoadMmaQ& pipeline_load_mma_q,\n      typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,\n      PipelineLoadMmaDO& pipeline_load_mma_do,\n      typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,\n      PipelineMmaComputeS& pipeline_mma_compute_s,\n      typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,\n      PipelineMmaComputeDP& pipeline_mma_compute_dp,\n      typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,\n      PipelineMmaReduceDQ& pipeline_mma_reduce_dq,\n      typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,\n      PipelineComputeMmaP& pipeline_compute_mma_p,\n      typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,\n      PipelineComputeMmaDS& pipeline_compute_mma_ds,\n      typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,\n      PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,\n      typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n\n    auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});\n    auto sK = 
make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});\n    auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});\n    auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});\n\n    auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{});\n    auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{});\n    auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{});\n    auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{});\n    auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{});\n    auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{});\n\n    Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK);\n    Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ);\n\n    Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV);\n    Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO);\n\n    Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS);\n    Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT);\n\n    Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST);\n    Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT);\n\n    Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});\n    tDVrP.data() = TmemAllocation::kP;\n    Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);\n\n    TiledMmaKQ tiled_mma_kq;\n    TiledMmaVDO tiled_mma_vdo;\n    TiledMmaDSK tiled_mma_dsk;\n    TiledMmaDSQ tiled_mma_dsq;\n    TiledMmaPDO tiled_mma_pdo;\n\n    tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero;\n    tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero;\n\n    Tensor tSTtST =  partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{}));\n    tSTtST.data() = TmemAllocation::kS;\n\n    Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{}));\n    tDPTtDPT.data() = TmemAllocation::kDP;\n\n    
Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{}));\n    tDQtDQ.data() = TmemAllocation::kDQ;\n\n    Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{}));\n    tDKtDK.data() = TmemAllocation::kDK;\n\n    Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{}));\n    tDVtDV.data() = TmemAllocation::kDV;\n\n    auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state;\n\n    pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);\n    pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);\n\n    // S = Q*K\n    tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {\n      cute::gemm(tiled_mma_kq,\n                 tSTrK(_,_,k_block,_0{}),\n                 tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),\n                 tSTtST);\n      tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    ++pipeline_load_mma_q_consumer_state;\n\n    pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);\n    ++pipeline_mma_compute_s_producer_state;\n\n    pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);\n\n    pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);\n    pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);\n\n    // dP = dO*V\n    tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {\n      cute::gemm(tiled_mma_vdo,\n                 tDPTrV(_,_,k_block,_0{}),\n                 tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                 tDPTtDPT);\n      tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);\n    
++pipeline_mma_compute_dp_producer_state;\n\n    pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);\n\n    // dV = P*dO\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {\n      cute::gemm(tiled_mma_pdo,\n                 tDVrP(_,_,k_block),\n                 tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                 tDVtDV);\n      tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);\n    ++pipeline_compute_mma_p_consumer_state;\n\n    pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);\n    ++pipeline_load_mma_do_consumer_state;\n\n    iter_count -= 1;\n\n    // in tmem, S & P overlap\n    // and dP and dQ overlap\n    // so we need to acquire dQ and dP at the same time\n    while (iter_count > 0) {\n      pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);\n      pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);\n\n      // S = Q*K\n      tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero;\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {\n        cute::gemm(tiled_mma_kq,\n                   tSTrK(_,_,k_block,_0{}),\n                   tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),\n                   tSTtST);\n        tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      ++pipeline_load_mma_q_consumer_state;\n\n      pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);\n      ++pipeline_mma_compute_s_producer_state;\n\n      pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);\n\n      // we need to acquire dP here, because tmem dQ == tmem dP\n      pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);\n\n      // dQ = dS*K\n      tiled_mma_dsk.accumulate_ = 
UMMA::ScaleOut::Zero;\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {\n        cute::gemm(tiled_mma_dsk,\n                   tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                   tDQrKT(_,_,k_block,_0{}),\n                   tDQtDQ);\n        tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);\n      ++pipeline_mma_reduce_dq_producer_state;\n\n      // dK = dS*Q\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {\n        cute::gemm(tiled_mma_dsq,\n                   tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                   tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),\n                   tDKtDK);\n        tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);\n      ++pipeline_load_mma_q_release_state;\n\n      pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);\n      ++pipeline_compute_mma_ds_consumer_state;\n\n      // we grab dq here, because in tmem dq == dp\n      pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);\n\n      pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);\n\n      // dP = dO*V\n      tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero;\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {\n        cute::gemm(tiled_mma_vdo,\n                   tDPTrV(_,_,k_block,_0{}),\n                   tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                   tDPTtDPT);\n        tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);\n      
++pipeline_mma_compute_dp_producer_state;\n\n      pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);\n\n      // dV = P*dO\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {\n        cute::gemm(tiled_mma_pdo,\n                   tDVrP(_,_,k_block),\n                   tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                   tDVtDV);\n        tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);\n      ++pipeline_compute_mma_p_consumer_state;\n\n      pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);\n      ++pipeline_load_mma_do_consumer_state;\n\n      iter_count -= 1;\n    }\n\n    // signal to the epilogue that dV is ready\n    pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);\n    pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);\n    ++pipeline_mma_compute_dkdv_producer_state;\n\n    pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);\n\n    pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);\n\n    // dK = dS*Q\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {\n      cute::gemm(tiled_mma_dsq,\n                 tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                 tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),\n                 tDKtDK);\n      tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    // signal to the epilogue that dK is ready\n    pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);\n    ++pipeline_mma_compute_dkdv_producer_state;\n\n    // we've already acquired mma_reduce_dq in the loop\n\n    // dQ = dS*K\n    tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;\n    
CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {\n      cute::gemm(tiled_mma_dsk,\n                 tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                 tDQrKT(_,_,k_block,_0{}),\n                 tDQtDQ);\n      tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);\n    ++pipeline_mma_reduce_dq_producer_state;\n\n    pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);\n    ++pipeline_load_mma_q_release_state;\n\n    pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);\n    ++pipeline_compute_mma_ds_consumer_state;\n  }\n\n\n\n  template<class TensorG, class TensorR, class TensorC, class TensorShape>\n  CUTLASS_DEVICE void store(\n      TensorG gmem,\n      TensorR const& regs,\n      TensorC const& coord,\n      TensorShape const& tensor_shape) {\n\n    Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });\n\n    auto copy_op = make_cotiled_copy(\n        Copy_Atom<UniversalCopy<uint128_t>, Element>{},\n        make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),\n        regs.layout()\n    );\n    auto thr_copy = copy_op.get_slice(_0{});\n\n    Tensor quantized_regs = quantize(regs);\n    Tensor tCr = thr_copy.partition_S(quantized_regs);\n    Tensor tCg = thr_copy.partition_D(gmem);\n    Tensor tPc = thr_copy.partition_D(preds);\n\n    copy_if(copy_op, tPc, tCr, tCg);\n  }\n\n\n  template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void epilogue_clear(\n      BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& problem_shape,\n      MainloopArguments const& mainloop_args,\n      EpilogueArguments const& epilogue_args) {\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, 
blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);\n    auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);\n    auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDK = domain_offset(\n        make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapeDSQ{}))\n    );\n\n    auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);\n    auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);\n    auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDV = domain_offset(\n        make_coord(blk_coord_k * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapePDO{}))\n    );\n\n    for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) {\n      if (elem_less(cDK(i), select<1,2>(problem_shape))) {\n        gDK(i) = Element(0);\n      }\n    }\n    for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {\n      if (elem_less(cDV(i), select<1,3>(problem_shape))) {\n        gDV(i) = Element(0);\n      }\n    }\n  }\n\n\n  template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void epilogue(\n      BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& problem_shape,\n      MainloopArguments const& mainloop_args,\n      EpilogueArguments const& epilogue_args,\n      PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,\n      typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};\n\n    auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});\n    tDKtDK.data() = TmemAllocation::kDK;\n\n    auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);\n    auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);\n    auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDK = domain_offset(\n        make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapeDSQ{}))\n    );\n\n    constexpr int kNumWarpgroups = kNumComputeWarps / 4;\n    int dp_idx = threadIdx.x % 128;\n    int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;\n\n    auto split_wg = [&](auto const& t) {\n      if constexpr (decltype(rank(t))::value == 3) {\n        auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));\n        return p(_, _, make_coord(wg_idx, _));\n      }\n      else {\n        auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));\n        return p(_, _, _, make_coord(wg_idx, _));\n      }\n    };\n\n    auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK);\n    auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx);\n\n    Tensor tTR_cDK   = split_wg(thread_t2r_dk.partition_D(cDK));\n    Tensor tTR_gDK   = split_wg(thread_t2r_dk.partition_D(gDK));\n    Tensor tTR_rDK = make_tensor<ElementAcc>(shape(tTR_cDK));\n    Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK));\n\n    auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});\n    tDVtDV.data() = TmemAllocation::kDV;\n\n    auto mDV_in = 
make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);\n    auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);\n    auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDV = domain_offset(\n        make_coord(blk_coord_k * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapePDO{}))\n    );\n\n    auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV);\n    auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx);\n\n    Tensor tTR_cDV   = split_wg(thread_t2r_dv.partition_D(cDV));\n    Tensor tTR_gDV   = split_wg(thread_t2r_dv.partition_D(gDV));\n    Tensor tTR_rDV = make_tensor<ElementAcc>(shape(tTR_cDV));\n    Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV));\n\n    pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);\n\n    // load tDVtDV\n    cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);\n\n    // store tDVgDV\n    store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));\n\n    cutlass::arch::fence_view_async_tmem_load();\n    pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);\n    ++pipeline_mma_compute_dkdv_consumer_state;\n\n    pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);\n\n    // load tDKtDK\n    cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK);\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(tTR_rDK); i++) {\n      tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i);\n    }\n\n    // store tDKgDK\n    store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape));\n\n    cutlass::arch::fence_view_async_tmem_load();\n    pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);\n    ++pipeline_mma_compute_dkdv_consumer_state;\n\n  }\n\n\n  template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void compute(\n      
BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      EpilogueArguments const& epilogue_args,\n      TensorStorage& shared_tensors,\n      PipelineLoadComputeLSE& pipeline_load_compute_lse,\n      typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state,\n      PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,\n      typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state,\n      PipelineMmaComputeS& pipeline_mma_compute_s,\n      typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state,\n      PipelineMmaComputeDP& pipeline_mma_compute_dp,\n      typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state,\n      PipelineComputeMmaP& pipeline_compute_mma_p,\n      typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state,\n      PipelineComputeMmaDS& pipeline_compute_mma_ds,\n      typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state,\n      PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,\n      typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {\n\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n\n    // in tmem, S & P overlap\n    // and dP and dQ overlap\n\n    // there are two compute wg's that cooperatively compute softmax\n    // they are striped by this tmem atom, i.e. 
wg0 has 16 elems, then wg1 etc\n\n    auto load_op = SM100_TMEM_LOAD_32dp32b16x{};\n    auto store_op = []() {\n      if constexpr (sizeof(Element) == 1) {\n        return SM100_TMEM_STORE_32dp32b4x{};\n      }\n      else {\n        return SM100_TMEM_STORE_32dp32b8x{};\n      }\n    }();\n\n    Tensor tSTtST =  partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{});\n    tSTtST.data() = TmemAllocation::kS;\n\n    Tensor tDPTtDPT =  partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{});\n    tDPTtDPT.data() = TmemAllocation::kDP;\n\n    Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{}));\n    Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{}));\n\n    constexpr int kNumWarpgroups = kNumComputeWarps / 4;\n    int dp_idx = threadIdx.x % 128;\n    int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;\n    auto tiled_t2r = make_tmem_copy(load_op, tSTtST);\n    auto thread_t2r = tiled_t2r.get_slice(dp_idx);\n\n    auto split_wg = [&](auto const& t) {\n      if constexpr (decltype(size<1>(t))::value > 1) {\n        if constexpr (decltype(rank(t))::value == 3) {\n          auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t))));\n          return p(_, make_coord(wg_idx, _), _);\n        }\n        else {\n          auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t), size<3>(t))));\n          return p(_, make_coord(wg_idx, _), _, _);\n        }\n      }\n      else {\n        if constexpr (decltype(rank(t))::value == 3) {\n          auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));\n          return p(_, _, make_coord(wg_idx, _));\n        }\n        else {\n          auto p = 
t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));\n          return p(_, _, _, make_coord(wg_idx, _));\n        }\n\n      }\n    };\n\n\n    Tensor tTR_cST_p = thread_t2r.partition_D(cST);\n    Tensor tTR_cST   = split_wg(tTR_cST_p);\n    Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));\n    Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));\n\n    Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);\n    Tensor tTR_cDPT = split_wg(tTR_cDPT_p);\n    Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));\n    Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT));\n\n    Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{});\n    Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{});\n\n    auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{});\n\n    auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});\n    auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST);\n    tDVrP.data() = TmemAllocation::kP;\n\n    auto tiled_r2t = make_tmem_copy(store_op, tDVrP);\n    auto thread_r2t = tiled_r2t.get_slice(dp_idx);\n\n    auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));\n    auto tRT_cST_p = thread_r2t.partition_S(tDVcST);\n    auto tRT_cST = split_wg(tRT_cST_p);\n\n    bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);\n    int last_iter = iter_count - 1 + iter_index;\n\n    CUTLASS_PRAGMA_NO_UNROLL\n    while (iter_count > 0) {\n      // wait for S and P\n      pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state);\n      pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state);\n      // wait for LSE\n      pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state);\n\n      auto dispatch_bool = 
[](bool b, auto fn) {\n        if (b) {\n          fn(cute::true_type{});\n        }\n        else {\n          fn(cute::false_type{});\n        }\n      };\n\n      bool leading_causal_masking = false;\n      if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {\n        leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));\n      } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {\n        int offset = get<1>(problem_shape) - get<0>(problem_shape);\n        int kv_left = get<1>(blk_coord) * TileShapeK{};\n        int kv_right = kv_left + TileShapeK{} - 1;\n        int q_left = iter_index * TileShapeQ{} + offset;\n        int q_right = q_left + TileShapeQ{} - 1;\n\n        leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));\n      }\n      bool trailing_residual_masking = false;\n      if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {\n        trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);\n      }\n\n      dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {\n\n        // compute P = softmax(S, LSE)\n        cute::copy(tiled_t2r, tTR_tST, tTR_rST);\n\n        if constexpr (decltype(is_masked_tile)::value) {\n          Mask{}.apply_mask(tTR_rST, [&](int i) {\n            auto c_transpose = tTR_cST(i);\n            return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});\n          }, problem_shape);\n        }\n\n        ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);\n        float2 softmax_scale_log2_e;\n        softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;\n        softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e;\n\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < size(tTR_rST); i += 2) {\n          
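// Sketch of the math in this loop, under one stated assumption:\n          // each pass handles two neighbouring elements as a float2 pair and computes\n          //   out = softmax_scale * log2(e) * S + lse,   P = exp2(out).\n          // This equals exp(softmax_scale * S - LSE) only if sLSE already holds -LSE * log2(e);\n          // that is an assumption about the producer of sLSE, it is not checked here.\n          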
float2 acc;\n          float2 lse;\n          float2 out;\n          acc.x = tTR_rST(i);\n          acc.y = tTR_rST(i + 1);\n          lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index());\n          lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index());\n          cute::fma(out, softmax_scale_log2_e, acc, lse);\n          tTR_rST(i) = ::exp2f(out.x);\n          tTR_rST(i+1) = ::exp2f(out.y);\n        }\n\n        auto tRT_rST = quantize(tTR_rST);\n        auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));\n\n        cutlass::arch::fence_view_async_tmem_load();\n        cutlass::arch::NamedBarrier(\n          kNumComputeWarps * NumThreadsPerWarp,\n          cutlass::arch::ReservedNamedBarriers::TransformBarrier\n        ).arrive_and_wait();\n\n        cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);\n      });\n\n      // notify for P\n      cutlass::arch::fence_view_async_tmem_store();\n      pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state);\n      ++pipeline_compute_mma_p_producer_state;\n      // release S\n      pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state);\n      ++pipeline_mma_compute_s_consumer_state;\n      // release LSE\n      pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state);\n      ++pipeline_load_compute_lse_consumer_state;\n\n      // wait for OdO\n      pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state);\n      // wait for dP\n      pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state);\n\n      // wait for dS\n      // in principle, we could defer waiting for dS, and move in the freeing of dP\n      // however, that would force us to keep dS in registers longer\n      pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state);\n\n      // compute dS = dsoftmax(P, dP, sum_OdO)\n      
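// Elementwise, the loop below evaluates dS = P * (dP + sum_OdO).\n      // sSumOdO already stores the negated term (see the note inside the loop), so this matches\n      // the usual softmax backward dS = P * (dP - rowsum(dO * O)), assuming sum_OdO denotes\n      // that row sum. tTR_rST is reused here as P from the softmax step above, and no\n      // softmax_scale factor is applied in this loop.\n      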
cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT);\n\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size(tTR_rDPT); i += 2) {\n        float2 st;\n        st.x = tTR_rST(i);\n        st.y = tTR_rST(i+1);\n        float2 dpt;\n        dpt.x = tTR_rDPT(i);\n        dpt.y = tTR_rDPT(i+1);\n        float2 odo;\n        odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index());\n        odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index());\n        float2 dif;\n        // sum odo is negated during preprocess\n        cute::add(dif, dpt, odo);\n        float2 out;\n        cute::mul(out, dif, st);\n        tTR_rDPT(i) = out.x;\n        tTR_rDPT(i+1) = out.y;\n      }\n\n      auto tTR_rDST = quantize(tTR_rDPT);\n\n      // release dP\n      cutlass::arch::fence_view_async_tmem_load();\n      pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state);\n      ++pipeline_mma_compute_dp_consumer_state;\n\n      Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{})\n          (_, _, _, pipeline_compute_mma_ds_producer_state.index());\n\n      auto thread_layout = make_ordered_layout(\n          make_shape(_128{}, _128{}),\n          make_stride(_1{}, _0{})\n      );\n\n      auto sDS_pi = as_position_independent_swizzle_tensor(sDS);\n      auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p)));\n      auto sDS_pi_slice = split_wg(sDS_pi_slice_p);\n\n      copy_aligned(tTR_rDST, sDS_pi_slice);\n\n      // notify for dS\n      cutlass::arch::fence_view_async_shared();\n      pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state);\n      ++pipeline_compute_mma_ds_producer_state;\n      // release OdO\n      pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state);\n      ++pipeline_load_compute_sum_odo_consumer_state;\n\n      
iter_count -= 1;\n      iter_index += 1;\n    }\n\n    epilogue(\n        blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,\n        pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state\n    );\n  }\n\n  template<class BlkCoord, class ProblemShape_>\n  CUTLASS_DEVICE void reduce(\n      BlkCoord const& blk_coord,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      MainloopParams const& mainloop_params,\n      TensorStorage& shared_tensors,\n      PipelineMmaReduceDQ& pipeline_mma_reduce_dq,\n      typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,\n      PipelineReduceTmaStore& pipeline_reduce_tma_store,\n      typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {\n\n    using X = Underscore;\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    // must match TileShapeDQ\n    auto load_op = SM100_TMEM_LOAD_32dp32b32x{};\n\n    auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{});\n    tDQtDQ.data() = TmemAllocation::kDQ;\n\n    Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));\n    auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})\n        (_, _, _, _0{}, blk_coord_batch);\n\n    Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));\n\n    Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{});\n\n    int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp);\n    auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ);\n    auto thread_t2r = tiled_t2r.get_slice(thread_idx);\n\n    Tensor tTR_cDQ   = thread_t2r.partition_D(cDQ);\n    Tensor tTR_gDQ   = thread_t2r.partition_D(gDQ);\n    Tensor tTR_sDQ   = 
thread_t2r.partition_D(sDQ);\n    Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);\n\n    auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{});\n\n    Tensor tDQsDQ = block_tma.partition_S(sDQ);\n    Tensor tDQcDQ = block_tma.partition_S(cDQ);\n    Tensor tDQgDQ = block_tma.partition_D(gDQ);\n\n    int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;\n\n    while (iter_count > 0) {\n      pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);\n\n      Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));\n\n      // load dQ from tmem to rmem\n      cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ);\n\n      cutlass::arch::fence_view_async_tmem_load();\n      pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state);\n      ++pipeline_mma_reduce_dq_consumer_state;\n\n      // we don't have enough smem to dump it all to smem, so we do it in stages\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size<2>(tTR_cDQ); i++) {\n        if (lane_predicate) {\n          pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state);\n        }\n        // wait in all threads for the acquire to complete\n        cutlass::arch::NamedBarrier(\n            kNumReduceWarps * NumThreadsPerWarp,\n            cutlass::arch::ReservedNamedBarriers::TransposeBarrier\n        ).arrive_and_wait();\n\n        cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index()));\n\n        // wait for the stores to all be visible to the TMA\n        cutlass::arch::fence_view_async_shared();\n        cutlass::arch::NamedBarrier(\n            kNumReduceWarps * NumThreadsPerWarp,\n            cutlass::arch::ReservedNamedBarriers::TransposeBarrier\n        ).arrive_and_wait();\n        if (lane_predicate) {\n          // launch tma store\n          copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), 
tDQgDQ(_,_,i,iter_index));\n          pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);\n        }\n\n        ++pipeline_reduce_tma_store_producer_state;\n      }\n\n      iter_count -= 1;\n      iter_index += 1;\n    }\n  }\n\n\n  CUTLASS_DEVICE void operator()(Params const& params, char* smem) {\n#if defined(KERUTILS_ENABLE_SM100A)\n    int warp_idx = cutlass::canonical_warp_idx_sync();\n    auto role = warp_idx_to_role(warp_idx);\n    uint32_t lane_predicate = cute::elect_one_sync();\n\n    if (role == WarpRole::Load && lane_predicate) {\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor());\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor());\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor());\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor());\n    }\n\n    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);\n\n    int initializing_warp = 0;\n    typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Mma) {\n      pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer;\n    }\n    pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load);\n    // Also loads K in the first iteration\n    pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ;\n    pipeline_load_mma_q_params.initializing_warp = initializing_warp++;\n    PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params;\n    if (role == WarpRole::Load) {\n      
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Mma) {\n      pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer;\n    }\n    pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load);\n    // Also loads V in the first iteration\n    pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO;\n    pipeline_load_mma_do_params.initializing_warp = initializing_warp++;\n    PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer;\n    }\n    pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp;\n    pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;\n    pipeline_load_compute_lse_params.initializing_warp = initializing_warp++;\n    PipelineLoadComputeLSE pipeline_load_compute_lse(\n      shared_storage.pipelines.load_compute_lse,\n      pipeline_load_compute_lse_params,\n      /*barrier init*/ cute::true_type{});\n\n    typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer;\n    }\n    pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp;\n    
pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;\n    pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++;\n    PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo(\n      shared_storage.pipelines.load_compute_sum_odo,\n      pipeline_load_compute_sum_odo_params,\n      /*barrier init*/ cute::true_type{});\n\n    typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params;\n    if (role == WarpRole::Mma) {\n      pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer;\n    }\n    pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_compute_s_params.initializing_warp = initializing_warp++;\n    PipelineMmaComputeS pipeline_mma_compute_s(\n      shared_storage.pipelines.mma_compute_s,\n      pipeline_mma_compute_s_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params;\n    if (role == WarpRole::Mma) {\n      pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer;\n    }\n    pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++;\n    PipelineMmaComputeDP pipeline_mma_compute_dp(\n      shared_storage.pipelines.mma_compute_dp,\n      pipeline_mma_compute_dp_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params;\n    if (role == 
WarpRole::Mma) {\n      pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Reduce) {\n      pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer;\n    }\n    pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++;\n    PipelineMmaReduceDQ pipeline_mma_reduce_dq(\n      shared_storage.pipelines.mma_reduce_dq,\n      pipeline_mma_reduce_dq_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params;\n    if (role == WarpRole::Mma) {\n      pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer;\n    }\n    pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_compute_mma_p_params.consumer_arv_count = 1;\n    pipeline_compute_mma_p_params.initializing_warp = initializing_warp++;\n    PipelineComputeMmaP pipeline_compute_mma_p(\n      shared_storage.pipelines.compute_mma_p,\n      pipeline_compute_mma_p_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params;\n    if (role == WarpRole::Mma) {\n      pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer;\n    }\n    pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_compute_mma_ds_params.consumer_arv_count = 1;\n    
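// initializing_warp++ gives each pipeline a different warp index. The assumption here is\n    // that the pipeline constructor lets only that warp initialize its barriers, which would\n    // spread barrier setup across warps instead of serializing it on one.\n    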
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++;\n    PipelineComputeMmaDS pipeline_compute_mma_ds(\n      shared_storage.pipelines.compute_mma_ds,\n      pipeline_compute_mma_ds_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params;\n    if (role == WarpRole::Mma) {\n      pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer;\n    }\n    pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++;\n    PipelineMmaComputeDKDV pipeline_mma_compute_dkdv(\n      shared_storage.pipelines.mma_compute_dkdv,\n      pipeline_mma_compute_dkdv_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n    PipelineReduceTmaStore pipeline_reduce_tma_store;\n\n    TmemAllocator tmem_allocator;\n\n    pipeline_init_arrive_relaxed(size(ClusterShape{}));\n\n    pipeline_load_mma_q.init_masks(ClusterShape{});\n    pipeline_load_mma_do.init_masks(ClusterShape{});\n    pipeline_mma_compute_s.init_masks(ClusterShape{});\n    pipeline_mma_compute_dp.init_masks(ClusterShape{});\n    pipeline_mma_reduce_dq.init_masks(ClusterShape{});\n    pipeline_compute_mma_p.init_masks(ClusterShape{});\n    pipeline_compute_mma_ds.init_masks(ClusterShape{});\n    pipeline_mma_compute_dkdv.init_masks(ClusterShape{});\n\n    typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state;\n    typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state;\n    typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state;\n    
typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state;\n    typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state;\n    typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state;\n    typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state;\n    typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;\n    typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;\n    typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;\n\n    auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();\n    auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();\n    auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();\n    auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state<decltype(pipeline_load_compute_sum_odo)>();\n    auto pipeline_mma_compute_s_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_s)>();\n    auto pipeline_mma_compute_dp_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dp)>();\n    auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state<decltype(pipeline_mma_reduce_dq)>();\n    auto pipeline_compute_mma_p_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_p)>();\n    auto pipeline_compute_mma_ds_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_ds)>();\n    auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dkdv)>();\n    auto pipeline_reduce_tma_store_producer_state = 
make_producer_start_state<decltype(pipeline_reduce_tma_store)>();\n\n    pipeline_init_wait(size(ClusterShape{}));\n\n    auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));\n    auto [problem_shape, blk_offset] = apply_variable_length_offset(\n        params.problem_shape,\n        blk_coord\n    );\n    int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});\n    int iter_start = 0;\n    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {\n      iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};\n    } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {\n      int offset = get<1>(problem_shape) - get<0>(problem_shape);\n      iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});\n    }\n    if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {\n      return;\n    }\n    iter_count -= iter_start;\n\n    if (iter_count <= 0) {\n      epilogue_clear(\n          blk_coord,\n          blk_offset,\n          problem_shape,\n          params.mainloop,\n          params.epilogue\n      );\n      return;\n    }\n\n    if (role == WarpRole::Load) {\n      warpgroup_reg_set<RegisterAllocation::kLoad>();\n\n      load(\n          blk_coord,\n          blk_offset,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          params.mainloop_params,\n          shared_storage.tensors,\n          pipeline_load_mma_q, pipeline_load_mma_q_producer_state,\n          pipeline_load_mma_do, pipeline_load_mma_do_producer_state,\n          pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,\n          pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state\n      );\n\n    }\n    else if (role == WarpRole::Mma) {\n      warpgroup_reg_set<RegisterAllocation::kMma>();\n\n      
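// Only the MMA warp allocates tensor memory. It requests the full column range\n      // (Sm100TmemCapacityColumns) and the base address is written to\n      // shared_storage.tmem_base_ptr; the Compute branch below frees the same range\n      // after its EpilogueBarrier.\n      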
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);\n      __syncwarp();\n\n      mma(\n          blk_coord,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          shared_storage.tensors,\n          pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,\n          pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,\n          pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,\n          pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,\n          pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state,\n          pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state,\n          pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state,\n          pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state\n      );\n\n    }\n    else if (role == WarpRole::Compute) {\n      warpgroup_reg_set<RegisterAllocation::kCompute>();\n\n      compute(\n          blk_coord,\n          blk_offset,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          params.epilogue,\n          shared_storage.tensors,\n          pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state,\n          pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state,\n          pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state,\n          pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state,\n          pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state,\n          pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state,\n          pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state\n      );\n\n      cutlass::arch::NamedBarrier(\n          kNumComputeWarps * NumThreadsPerWarp,\n          cutlass::arch::ReservedNamedBarriers::EpilogueBarrier\n      ).arrive_and_wait();\n\n      if (warp_idx % 
kNumComputeWarps == 0) {\n        uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;\n        tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);\n      }\n\n    }\n    else if (role == WarpRole::Reduce) {\n      warpgroup_reg_set<RegisterAllocation::kReduce>();\n\n      reduce(\n          blk_coord,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          params.mainloop_params,\n          shared_storage.tensors,\n          pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,\n          pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state\n      );\n\n      pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);\n    }\n    else {\n      warpgroup_reg_set<RegisterAllocation::kEmpty>();\n\n      /* no-op */\n\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\\n\");\n    }\n#endif\n  }\n\n  static dim3 get_block_shape() {\n    dim3 block(MaxThreadsPerBlock, 1, 1);\n    return block;\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    auto [Q, K, D, D_VO, HB] = params.problem_shape;\n    auto [H, B] = HB;\n    dim3 grid(ceil_div(K, TileShapeK{}), H, B);\n    return grid;\n  }\n};\n\n}  // namespace cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2025  - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n\n#include \"cute/tensor.hpp\"\n#include \"cute/arch/simd_sm100.hpp\"\n\n#include \"cutlass/arch/arch.h\"\n#include \"cutlass/arch/memory_sm80.h\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n\n#include <kerutils/kerutils.cuh> // for  KERUTILS_ENABLE_SM100A\n#include \"../collective/fmha_common.hpp\"\n\n#include <cmath>\n\nnamespace cutlass::fmha::kernel {\n\nusing namespace cutlass::fmha::collective;\n\nusing namespace cute;\n\ntemplate<\n    class ProblemShape,\n    class Element,\n    class ElementAcc,\n    class TileShape,\n    class Mask\n>\nstruct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {\n\n  using TileShapeQ = decltype(get<0>(TileShape{}));\n  using TileShapeK = decltype(get<1>(TileShape{}));\n  using TileShapeDQK = decltype(get<2>(TileShape{}));\n  using TileShapeDVO = decltype(get<3>(TileShape{}));\n\n  using TmemAllocator = cute::TMEM::Allocator1Sm;\n  struct TmemAllocation {\n    static constexpr uint32_t kDK = 0;                     // TileShapeK x TileShapeDQK x acc\n    static constexpr uint32_t kDV = kDK + TileShapeDQK{};  // TileShapeK x TileShapeDVO x acc\n    static constexpr uint32_t kDQ = kDV + TileShapeDVO{};  // TileShapeQ x TileShapeDQK x acc\n    static constexpr uint32_t kDP = kDQ;                   // TileShapeK x TileShapeQ   x inp\n    
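// kS = kDQ + 65536 * 16: assuming the usual tensor-memory address encoding (lane index in\n    // the upper 16 bits, column index in the lower 16 bits), this keeps the column of kDQ and\n    // offsets the lane by 16. kP aliases kS and kDP aliases kDQ, so those pairs share storage.\n    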
static constexpr uint32_t kS = kDQ + 65536 * 16;\n    static constexpr uint32_t kP = kS;\n    static constexpr uint32_t kTotal = kDQ + TileShapeDQK{};\n  };\n\n  static_assert(\n      static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns,\n      \"using too much tmem\"\n  );\n\n  enum class WarpRole {\n    Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4\n  };\n\n  static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull;\n  static constexpr int kNumComputeWarps = 8;\n  static constexpr int kNumReduceWarps = 4;\n\n  static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp;\n  static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, \"TileShapeQ must be divisible by NumThreadsPerWarp\");\n  CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {\n    return static_cast<WarpRole>((kWarpAssignment >> (4 * warp_idx)) & 0xF);\n  }\n\n  struct RegisterAllocation {\n    static constexpr int kWarpgroup0 = 160-8;\n    static constexpr int kWarpgroup1 = 128;\n    static constexpr int kWarpgroup2 = 96;\n    static constexpr int kReduce = kWarpgroup0;\n    static constexpr int kCompute = kWarpgroup1;\n    static constexpr int kMma = kWarpgroup2;\n    static constexpr int kEmpty = kWarpgroup2;\n    static constexpr int kLoad = kWarpgroup2;\n\n    static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512);\n  };\n\n  using ArchTag = cutlass::arch::Sm100;\n\n  using ClusterShape = Shape<_1, _1, _1>;\n  using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;\n\n  static constexpr int MinBlocksPerMultiprocessor = 1;\n  static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4;\n  static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps;\n\n  static constexpr int Alignment = 128 / sizeof_bits_v<Element>;\n  static constexpr int kStages = 2;\n\n  using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;\n  using TensorStrideContiguousMN = Stride<_1, int, 
Stride<int, int>>;\n\n  // compute S\n  using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      Element, TensorStrideContiguousK, Alignment,\n      Element, TensorStrideContiguousK, Alignment,\n      ElementAcc,\n      Shape<TileShapeQ, TileShapeK, TileShapeDQK>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeQK = typename CollectiveMmaQK::TileShape;\n  using TiledMmaQK = typename CollectiveMmaQK::TiledMma;\n\n  // compute dP\n  using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      Element, TensorStrideContiguousK, Alignment,\n      Element, TensorStrideContiguousK, Alignment,\n      ElementAcc,\n      Shape<TileShapeQ, TileShapeK, TileShapeDVO>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeDOV = typename CollectiveMmaDOV::TileShape;\n  using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma;\n\n  // compute dV\n  using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // needs to match ordering of S calculation\n      Element, TensorStrideContiguousK, Alignment,\n      Element, TensorStrideContiguousMN, Alignment,\n      ElementAcc,\n      Shape<TileShapeK, TileShapeDVO, TileShapeQ>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapePDO = typename CollectiveMmaPDO::TileShape;\n  using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma;\n\n  // compute dK\n  using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // somewhat arbitrary since we dump to smem, need to agree with the next one\n    
  Element, TensorStrideContiguousK , Alignment,\n      Element, TensorStrideContiguousMN, Alignment,\n      ElementAcc,\n      Shape<TileShapeK, TileShapeDQK, TileShapeQ>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape;\n  using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma;\n\n  // compute dQ\n  using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder<\n      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,\n      // somewhat arbitrary since we dump to smem, need to agree with the previous one\n      Element, TensorStrideContiguousMN, Alignment,\n      Element, TensorStrideContiguousMN, Alignment,\n      ElementAcc,\n      Shape<TileShapeQ, TileShapeDQK, TileShapeK>,\n      ClusterShape, cutlass::gemm::collective::StageCount<kStages>,\n      Schedule>::CollectiveOp;\n  using TileShapeDSK = typename CollectiveMmaDSK::TileShape;\n  using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma;\n\n  // pipelines are named Pipeline<Producer><Consumer><Resource>\n  static constexpr int kStagesComputeSmem = 1;\n  using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>;\n  using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>;\n  using PipelineLoadComputeLSE = PipelineAsync<1>;\n  using PipelineLoadComputeSumOdO = PipelineAsync<1>;\n  using PipelineMmaComputeS = PipelineUmmaAsync<1>;\n  using PipelineMmaComputeDP = PipelineUmmaAsync<1>;\n  using PipelineMmaReduceDQ = PipelineUmmaAsync<1>;\n  using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>;\n  using PipelineComputeMmaDS = PipelineUmmaConsumerAsync<kStagesComputeSmem>;\n  using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>;\n  static constexpr int kStagesReduceTmaStore = 2;\n  using PipelineReduceTmaStore = PipelineTmaStore<kStagesReduceTmaStore>;\n\n  struct PipelineStorage {\n    alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q;\n    
alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do;\n    alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse;\n    alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo;\n    alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s;\n    alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp;\n    alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq;\n    alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p;\n    alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds;\n    alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv;\n  };\n\n  template<class Layout, class Stages = _1>\n  static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) {\n    return composition(layout, make_tuple(_, _, _, make_layout(stages)));\n  }\n\n  using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{}));\n  using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{}));\n  using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{}));\n  using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{}));\n  using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int<kStagesComputeSmem>{}));\n  using SmemLayoutLSE = Layout<Shape<TileShapeQ, _1>>;\n  using SmemLayoutSumOdO = Layout<Shape<TileShapeQ, _1>>;\n\n  using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{}));\n  using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{}));\n  using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int<kStagesComputeSmem>{}));\n  using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{}));\n  using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{}));\n  using SmemLayoutPT = 
decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{}));\n\n  using TileShapeDQ = _32;\n  using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<\n      cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ\n  >());\n  using SmemShapeDQ = Shape<TileShapeQ, TileShapeDQ, Int<kStagesReduceTmaStore>>;\n  using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{}));\n\n  struct TensorStorage {\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutK>> smem_k;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKT>> smem_k_t;\n    };\n    alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQT>> smem_q_t;\n    };\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDO>> smem_do;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDOT>> smem_do_t;\n    };\n    union {\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDST>> smem_ds_t;\n    };\n    union{\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;\n      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutPT>> smem_p_t;\n    };\n    alignas(1024) cute::array<ElementAcc, cute::cosize_v<SmemLayoutDQ>> smem_dq;\n    alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutLSE>> smem_lse;\n    alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutSumOdO>> smem_sum_odo;\n  };\n\n  static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);\n  static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * 
cute::sizeof_bits_v<Element>);\n\n  static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);\n  static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);\n\n  struct SharedStorage {\n    TensorStorage tensors;\n    PipelineStorage pipelines;\n    uint32_t tmem_base_ptr;\n  };\n\n  // this is tight enough that it won't work with sizeof due to padding for alignment\n  static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);\n  static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, \"using too much smem\");\n\n  using TensorStride = TensorStrideContiguousK;  // S D (H B)\n  using RowTensorStride = Stride<_1, Stride<int, int>>;    // S (H B)\n\n  struct MainloopArguments {\n    const Element* ptr_q;\n    TensorStride stride_q;\n    const Element* ptr_k;\n    TensorStride stride_k;\n    const Element* ptr_v;\n    TensorStride stride_v;\n    const Element* ptr_do;\n    TensorStride stride_do;\n\n    const ElementAcc* ptr_lse;\n    RowTensorStride stride_lse;\n\n    const ElementAcc* ptr_sum_odo;\n    RowTensorStride stride_sum_odo;\n\n    ElementAcc* ptr_dq_acc;\n    TensorStride stride_dq_acc;\n\n    ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{});\n  };\n\n  using TMA_K = typename CollectiveMmaQK::Params::TMA_B;\n  using TMA_V = typename CollectiveMmaDOV::Params::TMA_B;\n  using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;\n  using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A;\n\n  using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},\n      make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),\n      SmemLayoutDQ{}(_, _, _0{})\n  ));\n\n  struct MainloopParams {\n    TMA_K tma_load_k;\n    TMA_V tma_load_v;\n    TMA_Q tma_load_q;\n    TMA_DO tma_load_do;\n    TMA_DQ tma_red_dq;\n  };\n\n  
struct EpilogueArguments {\n    Element* ptr_dk;\n    TensorStride stride_dk;\n    Element* ptr_dv;\n    TensorStride stride_dv;\n  };\n\n  struct Arguments {\n    ProblemShape problem_shape;\n    MainloopArguments mainloop;\n    EpilogueArguments epilogue;\n    KernelHardwareInfo hw_info;\n  };\n\n  struct Params {\n    ProblemShape problem_shape;\n    MainloopArguments mainloop;\n    MainloopParams mainloop_params;\n    EpilogueArguments epilogue;\n    KernelHardwareInfo hw_info;\n  };\n\n\n  static bool can_implement(Arguments const& args) {\n    auto [Q, K, D, D_VO, HB] = args.problem_shape;\n    auto [H, B] = HB;\n    if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) {\n      return false;\n    }\n    if (D % Alignment != 0 || D_VO % Alignment != 0) {\n      return false;\n    }\n    return true;\n  }\n\n\n  static Status initialize_workspace(Arguments const&, void*, cudaStream_t) {\n    return Status::kSuccess;\n  }\n\n\n  static Params to_underlying_arguments(Arguments const& args, void*) {\n    auto [Q_, K_, D, D_VO, HB] = args.problem_shape;\n    int Q = Q_;\n    int K = K_;\n\n    if constexpr (is_variable_length_v<decltype(Q_)>) {\n      Q = Q_.total_length;\n    }\n    if constexpr (is_variable_length_v<decltype(K_)>) {\n      K = K_.total_length;\n    }\n\n    auto params_kq = CollectiveMmaQK::to_underlying_arguments(\n      make_shape(Q, K, D, HB),\n      typename CollectiveMmaQK::Arguments {\n        args.mainloop.ptr_q, args.mainloop.stride_q,\n        args.mainloop.ptr_k, args.mainloop.stride_k,\n      }, /*workspace=*/nullptr);\n\n    auto params_vdo = CollectiveMmaDOV::to_underlying_arguments(\n      make_shape(Q, K, D_VO, HB),\n      typename CollectiveMmaDOV::Arguments {\n        args.mainloop.ptr_do, args.mainloop.stride_do,\n        args.mainloop.ptr_v, args.mainloop.stride_v,\n      }, /*workspace=*/nullptr);\n\n    TMA_DQ tma_red_dq = make_tma_copy(\n        SM90_TMA_REDUCE_ADD{},\n        
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),\n        SmemLayoutDQ{}(_, _, _0{})\n    );\n\n    return Params{\n      args.problem_shape,\n      args.mainloop,\n      MainloopParams{\n        params_kq.tma_load_b,\n        params_vdo.tma_load_b,\n        params_kq.tma_load_a,\n        params_vdo.tma_load_a,\n        tma_red_dq\n      },\n      args.epilogue,\n      args.hw_info\n    };\n  }\n\n\n  template<class T>\n  static CUTLASS_DEVICE auto quantize(T const& input) {\n    constexpr int AlignmentS = 4;\n    auto output = make_tensor<Element>(shape(input));\n    auto input_vec = recast<Array<ElementAcc, AlignmentS>>(input);\n    auto output_vec = recast<Array<Element, AlignmentS>>(output);\n\n    cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(input_vec); i++) {\n      output_vec(i) = epilogue_op(input_vec(i));\n    }\n\n    return output;\n  }\n\n\n  template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void load(\n      BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      MainloopParams const& mainloop_params,\n      TensorStorage& shared_tensors,\n      PipelineLoadMmaQ& pipeline_load_mma_q,\n      typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state,\n      PipelineLoadMmaDO& pipeline_load_mma_do,\n      typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state,\n      PipelineLoadComputeLSE& pipeline_load_compute_lse,\n      typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state,\n      PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,\n      typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {\n\n    auto [Q, K, D, 
D_VO, HB] = problem_shape;\n\n    using X = Underscore;\n\n    uint16_t mcast_mask = 0;\n\n    auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));\n    auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));\n    auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));\n    auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));\n\n    auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);\n    auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);\n    auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);\n    auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);\n\n    auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});\n    auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});\n    auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step<X, _1, _1>{});\n    auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{});\n\n    ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{});\n    ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{});\n\n    auto tSTgK = cta_mma_kq.partition_B(gK);\n    auto tSTgQ = cta_mma_kq.partition_A(gQ);\n    auto tDPTgV = cta_mma_vdo.partition_B(gV);\n    auto tDPTgDO = cta_mma_vdo.partition_A(gDO);\n\n    auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});\n    auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});\n    auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});\n    auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});\n\n    auto [tKgK_mkl, tKsK] = tma_partition(\n        mainloop_params.tma_load_k, _0{}, make_layout(_1{}),\n        group_modes<0,3>(sK), group_modes<0,3>(tSTgK));\n    auto [tQgQ_mkl, tQsQ] = tma_partition(\n        mainloop_params.tma_load_q, _0{}, make_layout(_1{}),\n        
group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ));\n    auto [tVgV_mkl, tVsV] = tma_partition(\n        mainloop_params.tma_load_v, _0{}, make_layout(_1{}),\n        group_modes<0,3>(sV), group_modes<0,3>(tDPTgV));\n    auto [tDOgDO_mkl, tDOsDO] = tma_partition(\n        mainloop_params.tma_load_do, _0{}, make_layout(_1{}),\n        group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));\n\n    // set up lse and sum_odo\n\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);\n    auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);\n\n    pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK);\n\n    // load K\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask),\n          tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch),\n          tKsK(_, _0{})\n      );\n    }\n\n    // load Q\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),\n          tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),\n          tQsQ(_, pipeline_load_mma_q_producer_state.index())\n      );\n    }\n\n    ++pipeline_load_mma_q_producer_state;\n\n    pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);\n\n    // load LSE\n    // 32 threads loading kLoadPerThread * 32 values of 32b each\n\n    int thread_idx = threadIdx.x % NumThreadsPerWarp;\n    int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread;\n    int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;\n    auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);\n    for (int i = 0; i < kLoadPerThread; i++) {\n      cutlass::arch::cp_async_zfill<4>(\n     
     shared_tensors.smem_lse.begin() + smem_idx + i,\n          &mLSE(gmem_idx + i, blk_coord_batch),\n          gmem_idx + i < Q\n      );\n    }\n\n    pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);\n    ++pipeline_load_compute_lse_producer_state;\n\n\n    pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);\n    tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);\n\n    pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);\n\n    // load V\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask),\n          tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch),\n          tVsV(_, _0{})\n      );\n    }\n\n    // load dO\n    if (cute::elect_one_sync()) {\n      cute::copy(\n          mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),\n          tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),\n          tDOsDO(_, pipeline_load_mma_do_producer_state.index())\n      );\n    }\n\n    ++pipeline_load_mma_do_producer_state;\n\n    pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);\n\n    // load sum_OdO\n    smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread;\n    gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;\n    auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);\n    for (int i = 0; i < kLoadPerThread; i++) {\n      cutlass::arch::cp_async_zfill<4>(\n          shared_tensors.smem_sum_odo.begin() + smem_idx + i,\n          &mSumOdO(gmem_idx + i, blk_coord_batch),\n          gmem_idx + i < Q\n      );\n    }\n\n    pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, 
cutlass::arch::cpasync_barrier_arrive);\n    ++pipeline_load_compute_sum_odo_producer_state;\n\n    iter_count -= 1;\n    iter_index += 1;\n\n    while (iter_count > 0) {\n      pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);\n      tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);\n\n      // load Q\n      if (cute::elect_one_sync()) {\n        cute::copy(\n            mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),\n            tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),\n            tQsQ(_, pipeline_load_mma_q_producer_state.index())\n        );\n      }\n\n      ++pipeline_load_mma_q_producer_state;\n\n      pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);\n\n      // load LSE\n      smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread;\n      gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;\n      for (int i = 0; i < kLoadPerThread; i++) {\n        cutlass::arch::cp_async_zfill<4>(\n            shared_tensors.smem_lse.begin() + smem_idx + i,\n            &mLSE(gmem_idx + i, blk_coord_batch),\n            gmem_idx + i < Q\n        );\n      }\n\n      pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);\n      ++pipeline_load_compute_lse_producer_state;\n\n      pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);\n      tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);\n\n      // load dO\n      if (cute::elect_one_sync()) {\n        cute::copy(\n            mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),\n            tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),\n            tDOsDO(_, pipeline_load_mma_do_producer_state.index())\n        );\n      }\n\n      ++pipeline_load_mma_do_producer_state;\n\n      
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);\n\n      // load sum_OdO\n      smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread;\n      gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;\n      for (int i = 0; i < kLoadPerThread; i++) {\n        cutlass::arch::cp_async_zfill<4>(\n            shared_tensors.smem_sum_odo.begin() + smem_idx + i,\n            &mSumOdO(gmem_idx + i, blk_coord_batch),\n            gmem_idx + i < Q\n        );\n      }\n\n      pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);\n      ++pipeline_load_compute_sum_odo_producer_state;\n\n      iter_count -= 1;\n      iter_index += 1;\n    }\n  }\n\n\n  template<class BlkCoord, class ProblemShape_>\n  CUTLASS_DEVICE void mma(\n      BlkCoord const& blk_coord,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      TensorStorage& shared_tensors,\n      PipelineLoadMmaQ& pipeline_load_mma_q,\n      typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,\n      PipelineLoadMmaDO& pipeline_load_mma_do,\n      typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,\n      PipelineMmaComputeS& pipeline_mma_compute_s,\n      typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,\n      PipelineMmaComputeDP& pipeline_mma_compute_dp,\n      typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,\n      PipelineMmaReduceDQ& pipeline_mma_reduce_dq,\n      typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,\n      PipelineComputeMmaP& pipeline_compute_mma_p,\n      typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,\n      PipelineComputeMmaDS& 
pipeline_compute_mma_ds,\n      typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,\n      PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,\n      typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n\n    auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});\n    auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});\n    auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});\n    auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});\n\n    auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{});\n    auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{});\n    auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{});\n    auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{});\n    auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{});\n    auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{});\n\n    Tensor tSTrK = TiledMmaQK::make_fragment_B(sK);\n    Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ);\n\n    Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV);\n    Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO);\n\n    Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS);\n    Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT);\n\n    Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST);\n    Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT);\n\n    Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP);\n    Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);\n\n    TiledMmaQK tiled_mma_qk;\n    TiledMmaDOV tiled_mma_dov;\n    TiledMmaDSK tiled_mma_dsk;\n    TiledMmaDSQ tiled_mma_dsq;\n    TiledMmaPDO tiled_mma_pdo;\n\n    
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero;\n    tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero;\n\n    Tensor tSTtST =  partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{}));\n    tSTtST.data() = TmemAllocation::kS;\n\n    Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{}));\n    tDPTtDPT.data() = TmemAllocation::kDP;\n\n    Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{}));\n    tDQtDQ.data() = TmemAllocation::kDQ;\n\n    Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{}));\n    tDKtDK.data() = TmemAllocation::kDK;\n\n    Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{}));\n    tDVtDV.data() = TmemAllocation::kDV;\n\n    auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state;\n\n    pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);\n    pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);\n\n    // S = Q*K\n    tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {\n      cute::gemm(tiled_mma_qk,\n                 tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),\n                 tSTrK(_,_,k_block,_0{}),\n                 tSTtST);\n      tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    ++pipeline_load_mma_q_consumer_state;\n\n    pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);\n    ++pipeline_mma_compute_s_producer_state;\n\n    pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);\n\n    pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);\n    pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);\n\n    // dP = dO*V\n    tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < 
size<2>(tDPTrV); ++k_block) {\n      cute::gemm(tiled_mma_dov,\n                 tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                 tDPTrV(_,_,k_block,_0{}),\n                 tDPTtDPT);\n      tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);\n    ++pipeline_mma_compute_dp_producer_state;\n\n    pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);\n\n    // dV = P*dO\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {\n      cute::gemm(tiled_mma_pdo,\n                 tDVrP(_,_,k_block,_0{}),\n                 tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                 tDVtDV);\n      tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);\n    ++pipeline_compute_mma_p_consumer_state;\n\n    pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);\n    ++pipeline_load_mma_do_consumer_state;\n\n    iter_count -= 1;\n\n    // in tmem, S & P overlap\n    // and dP and dQ overlap\n    // so we need to acquire dQ and dP at the same time\n    while (iter_count > 0) {\n      pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);\n      pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);\n\n      // S = Q*K\n      tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {\n        cute::gemm(tiled_mma_qk,\n                   tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),\n                   tSTrK(_,_,k_block,_0{}),\n                   tSTtST);\n        tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      ++pipeline_load_mma_q_consumer_state;\n\n      
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);\n      ++pipeline_mma_compute_s_producer_state;\n\n      pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);\n\n      // we need to acquire dP here, because tmem dQ == tmem dP\n      pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);\n\n      // dQ = dS*K\n      tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {\n        cute::gemm(tiled_mma_dsk,\n                   tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                   tDQrKT(_,_,k_block,_0{}),\n                   tDQtDQ);\n        tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);\n      ++pipeline_mma_reduce_dq_producer_state;\n\n      // dK = dS*Q\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {\n        cute::gemm(tiled_mma_dsq,\n                   tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                   tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),\n                   tDKtDK);\n        tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);\n      ++pipeline_load_mma_q_release_state;\n\n      pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);\n      ++pipeline_compute_mma_ds_consumer_state;\n\n      // we grab dq here, because in tmem dq == dp\n      pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);\n\n      pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);\n\n      // dP = dO*V\n      tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero;\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; 
k_block < size<2>(tDPTrV); ++k_block) {\n        cute::gemm(tiled_mma_dov,\n                   tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                   tDPTrV(_,_,k_block,_0{}),\n                   tDPTtDPT);\n        tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);\n      ++pipeline_mma_compute_dp_producer_state;\n\n      pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);\n\n      // dV = P*dO\n      CUTLASS_PRAGMA_UNROLL\n      for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {\n        cute::gemm(tiled_mma_pdo,\n                   tDVrP(_,_,k_block,_0{}),\n                   tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),\n                   tDVtDV);\n        tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;\n      }\n\n      pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);\n      ++pipeline_compute_mma_p_consumer_state;\n\n      pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);\n      ++pipeline_load_mma_do_consumer_state;\n\n      iter_count -= 1;\n    }\n\n    // signal to the epilogue that dV is ready\n    pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);\n    pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);\n    ++pipeline_mma_compute_dkdv_producer_state;\n\n    pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);\n\n    pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);\n\n    // dK = dS*Q\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {\n      cute::gemm(tiled_mma_dsq,\n                 tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                 tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),\n            
     tDKtDK);\n      tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    // signal to the epilogue that dK is ready\n    pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);\n    ++pipeline_mma_compute_dkdv_producer_state;\n\n    // we've already acquired mma_reduce_dq in the loop\n\n    // dQ = dS*K\n    tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {\n      cute::gemm(tiled_mma_dsk,\n                 tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),\n                 tDQrKT(_,_,k_block,_0{}),\n                 tDQtDQ);\n      tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;\n    }\n\n    pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);\n    ++pipeline_mma_reduce_dq_producer_state;\n\n    pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);\n    ++pipeline_load_mma_q_release_state;\n\n    pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);\n    ++pipeline_compute_mma_ds_consumer_state;\n  }\n\n\n\n  template<class TensorG, class TensorR, class TensorC, class TensorShape>\n  CUTLASS_DEVICE void store(\n      TensorG gmem,\n      TensorR const& regs,\n      TensorC const& coord,\n      TensorShape const& tensor_shape) {\n  \n    Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });\n\n    auto copy_op = make_cotiled_copy(\n        Copy_Atom<UniversalCopy<uint128_t>, Element>{},\n        make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),\n        regs.layout()\n    );\n    auto thr_copy = copy_op.get_slice(_0{});\n\n    Tensor quantized_regs = quantize(regs);\n    Tensor tCr = thr_copy.partition_S(quantized_regs);\n    Tensor tCg = thr_copy.partition_D(gmem);\n    Tensor tPc = thr_copy.partition_D(preds);\n \n    copy_if(copy_op, tPc, tCr, tCg);\n  }\n\n\n  
template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void epilogue_clear(\n      BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& problem_shape,\n      MainloopArguments const& mainloop_args,\n      EpilogueArguments const& epilogue_args) {\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);\n    auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);\n    auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDK = domain_offset(\n        make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapeDSQ{}))\n    );\n\n    auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);\n    auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);\n    auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDV = domain_offset(\n        make_coord(blk_coord_k * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapePDO{}))\n    );\n    \n    for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) {\n      if (elem_less(cDK(i), select<1,2>(problem_shape))) {\n        gDK(i) = Element(0);\n      }\n    }\n    for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {\n      if (elem_less(cDV(i), select<1,3>(problem_shape))) {\n        gDV(i) = Element(0);\n      }\n    }\n\n  }\n\n\n  template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void epilogue(\n      BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& 
problem_shape,\n      MainloopArguments const& mainloop_args,\n      EpilogueArguments const& epilogue_args,\n      PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,\n      typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    auto load_op = SM100_TMEM_LOAD_32dp32b16x{};\n\n    auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});\n    tDKtDK.data() = TmemAllocation::kDK;\n\n    auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);\n    auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);\n    auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDK = domain_offset(\n        make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapeDSQ{}))\n    );\n\n    constexpr int kNumWarpgroups = kNumComputeWarps / 4;\n    int dp_idx = threadIdx.x % 128;\n    int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;\n\n    auto split_wg = [&](auto const& t) {\n      if constexpr (decltype(rank(t))::value == 3) {\n        auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));\n        return p(_, _, make_coord(wg_idx, _));\n      }\n      else {\n        auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));\n        return p(_, _, _, make_coord(wg_idx, _));\n      }\n    };\n\n    auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK);\n    auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx);\n\n    Tensor tTR_cDK   = 
split_wg(thread_t2r_dk.partition_D(cDK));\n    Tensor tTR_gDK   = split_wg(thread_t2r_dk.partition_D(gDK));\n    Tensor tTR_rDK = make_tensor<ElementAcc>(shape(tTR_cDK));\n    Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK));\n\n    auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{});\n    tDVtDV.data() = TmemAllocation::kDV;\n\n    auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);\n    auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);\n    auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})\n        (_, _, blk_coord_k, _0{}, blk_coord_batch);\n\n    Tensor cDV = domain_offset(\n        make_coord(blk_coord_k * TileShapeK{}, _0{}),\n        make_identity_tensor(take<0,2>(TileShapePDO{}))\n    );\n\n    auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV);\n    auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx);\n\n    Tensor tTR_cDV   = split_wg(thread_t2r_dv.partition_D(cDV));\n    Tensor tTR_gDV   = split_wg(thread_t2r_dv.partition_D(gDV));\n    Tensor tTR_rDV = make_tensor<ElementAcc>(shape(tTR_cDV));\n    Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV));\n\n    pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);\n\n    // load tDVtDV\n    cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);\n\n    // store tDVgDV\n    store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));\n\n    cutlass::arch::fence_view_async_tmem_load();\n    pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);\n    ++pipeline_mma_compute_dkdv_consumer_state;\n\n    pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);\n\n    // load tDKtDK\n    cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK);\n\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < size(tTR_rDK); i++) {\n      tTR_rDK(i) = mainloop_args.softmax_scale * 
tTR_rDK(i);\n    }\n\n    // store tDKgDK\n    store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape));\n\n    cutlass::arch::fence_view_async_tmem_load();\n    pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);\n    ++pipeline_mma_compute_dkdv_consumer_state;\n\n  }\n\n\n  template<class BlkCoord, class BlkOffset, class ProblemShape_>\n  CUTLASS_DEVICE void compute(\n      BlkCoord const& blk_coord,\n      BlkOffset const& blk_offset,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      EpilogueArguments const& epilogue_args,\n      TensorStorage& shared_tensors,\n      PipelineLoadComputeLSE& pipeline_load_compute_lse,\n      typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state,\n      PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,\n      typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state,\n      PipelineMmaComputeS& pipeline_mma_compute_s,\n      typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state,\n      PipelineMmaComputeDP& pipeline_mma_compute_dp,\n      typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state,\n      PipelineComputeMmaP& pipeline_compute_mma_p,\n      typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state,\n      PipelineComputeMmaDS& pipeline_compute_mma_ds,\n      typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state,\n      PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,\n      typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {\n\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n\n    // in tmem, S & P overlap\n    // and dP and dQ overlap\n\n    // there are two compute wg's that cooperatively compute softmax\n    // they are striped by this tmem atom, i.e. 
wg0 has 16 elems, then wg1 etc\n\n    auto load_op = SM100_TMEM_LOAD_16dp32b32x{};\n\n    Tensor tSTtST =  partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{});\n    tSTtST.data() = TmemAllocation::kS;\n\n    Tensor tDPTtDPT =  partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{});\n    tDPTtDPT.data() = TmemAllocation::kDP;\n\n    Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{}));\n    Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{}));\n    Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{}));\n\n    constexpr int kNumWarpgroups = kNumComputeWarps / 4;\n    int dp_idx = threadIdx.x % 128;\n    int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;\n    auto tiled_t2r = make_tmem_copy(load_op, tSTtST);\n    auto thread_t2r = tiled_t2r.get_slice(dp_idx);\n\n    auto split_wg = [&](auto const& t) {\n      if constexpr (decltype(size<1>(t))::value > 1) {\n        if constexpr (decltype(rank(t))::value == 3) {\n          auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t))));\n          return p(_, make_coord(wg_idx, _), _);\n        }\n        else {\n          auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t), size<3>(t))));\n          return p(_, make_coord(wg_idx, _), _, _);\n        }\n      }\n      else {\n        if constexpr (decltype(rank(t))::value == 3) {\n          auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));\n          return p(_, _, make_coord(wg_idx, _));\n        }\n        else {\n          auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));\n          return p(_, _, _, 
make_coord(wg_idx, _));\n        }\n      }\n    };\n\n    Tensor tTR_cST_p = thread_t2r.partition_D(cST);\n    Tensor tTR_cST   = split_wg(tTR_cST_p);\n    Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));\n    Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));\n\n    Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);\n    Tensor tTR_cPT_p = thread_t2r.partition_D(cPT);\n    Tensor tTR_cDPT = split_wg(tTR_cDPT_p);\n    Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));\n    Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT));\n\n    Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{});\n    Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{});\n\n    bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);\n    int last_iter = iter_count - 1 + iter_index;\n\n    CUTLASS_PRAGMA_NO_UNROLL\n    while (iter_count > 0) {\n      // wait for S and P\n      pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state);\n      pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state);\n      // wait for LSE\n      pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state);\n\n      auto dispatch_bool = [](bool b, auto fn) {\n        if (b) {\n          fn(cute::true_type{});\n        }\n        else {\n          fn(cute::false_type{});\n        }\n      };\n\n      bool leading_causal_masking = false;\n      if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {\n        leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));\n      } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {\n        int offset = get<1>(problem_shape) - get<0>(problem_shape);\n        int kv_left = get<1>(blk_coord) * TileShapeK{};\n        int kv_right = kv_left + TileShapeK{} - 1;\n        int 
q_left = iter_index * TileShapeQ{} + offset;\n        int q_right = q_left + TileShapeQ{} - 1;\n\n        leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));\n      }\n      bool trailing_residual_masking = false;\n      if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {\n        trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);\n      }\n\n      dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {\n\n        // compute P = softmax(S, LSE)\n        cute::copy(tiled_t2r, tTR_tST, tTR_rST);\n\n        if constexpr (decltype(is_masked_tile)::value) {\n          Mask{}.apply_mask(tTR_rST, [&](int i) {\n            auto c_transpose = tTR_cST(i);\n            return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{});\n          }, problem_shape);\n        }\n\n        ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);\n        float2 softmax_scale_log2_e;\n        softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;\n        softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e;\n\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < size(tTR_rST); i += 2) {\n          float2 acc;\n          float2 lse;\n          float2 out;\n          acc.x = tTR_rST(i);\n          acc.y = tTR_rST(i + 1);\n          lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index());\n          lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index());\n          cute::fma(out, softmax_scale_log2_e, acc, lse);\n          tTR_rST(i) = ::exp2f(out.x);\n          tTR_rST(i+1) = ::exp2f(out.y);\n        }\n\n        auto tRT_rST = quantize(tTR_rST);\n\n        Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})\n          (_, _, _, 
pipeline_compute_mma_p_producer_state.index());\n\n        cutlass::arch::fence_view_async_tmem_load();\n        cutlass::arch::NamedBarrier(\n          kNumComputeWarps * NumThreadsPerWarp,\n          cutlass::arch::ReservedNamedBarriers::TransformBarrier\n        ).arrive_and_wait();\n\n        auto sP_pi = as_position_independent_swizzle_tensor(sP);\n\n        auto thread_layout = make_ordered_layout(\n            make_shape(_64{}, _32{}, _2{}, _2{}),\n            make_stride(_3{}, _0{}, _1{}, _2{})\n            );\n        auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p)));\n        auto sP_pi_slice = split_wg(sP_pi_slice_p);\n        copy_aligned(tRT_rST, sP_pi_slice);\n      });\n\n      // notify for P\n      cutlass::arch::fence_view_async_shared();\n      pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state);\n      ++pipeline_compute_mma_p_producer_state;\n      // release S\n      pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state);\n      ++pipeline_mma_compute_s_consumer_state;\n      // release LSE\n      pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state);\n      ++pipeline_load_compute_lse_consumer_state;\n\n      // wait for OdO\n      pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state);\n      // wait for dP\n      pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state);\n\n      // wait for dS\n      // in principle, we could defer waiting for dS, and move in the freeing of dP\n      // however, that would force us to keep dS in registers longer\n      pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state);\n\n      // compute dS = dsoftmax(P, dP, sum_OdO)\n      cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT);\n\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size(tTR_rDPT); i 
+= 2) {\n        float2 st;\n        st.x = tTR_rST(i);\n        st.y = tTR_rST(i+1);\n        float2 dpt;\n        dpt.x = tTR_rDPT(i);\n        dpt.y = tTR_rDPT(i+1);\n        float2 odo;\n        odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index());\n        odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index());\n        float2 dif;\n        // sum odo is negated during preprocess\n        cute::add(dif, dpt, odo);\n        float2 out;\n        cute::mul(out, dif, st);\n        tTR_rDPT(i) = out.x;\n        tTR_rDPT(i+1) = out.y;\n      }\n\n      auto tTR_rDST = quantize(tTR_rDPT);\n\n      // release dP\n      cutlass::arch::fence_view_async_tmem_load();\n      pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state);\n      ++pipeline_mma_compute_dp_consumer_state;\n\n      Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{})\n          (_, _, _, pipeline_compute_mma_ds_producer_state.index());\n\n      auto thread_layout = make_ordered_layout(\n          make_shape(_64{}, _32{}, _2{}, _2{}),\n          make_stride(_3{}, _0{}, _1{}, _2{})\n          );\n      auto sDS_pi = as_position_independent_swizzle_tensor(sDS);\n      auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cDPT_p)));\n      auto sDS_pi_slice = split_wg(sDS_pi_slice_p);\n\n      copy_aligned(tTR_rDST, sDS_pi_slice);\n\n      // notify for dS\n      cutlass::arch::fence_view_async_shared();\n      pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state);\n      ++pipeline_compute_mma_ds_producer_state;\n      // release OdO\n      pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state);\n      ++pipeline_load_compute_sum_odo_consumer_state;\n\n      iter_count -= 1;\n      iter_index += 
1;\n    }\n\n    epilogue(\n        blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,\n        pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state\n    );\n  }\n\n  template<class BlkCoord, class ProblemShape_>\n  CUTLASS_DEVICE void reduce(\n      BlkCoord const& blk_coord,\n      ProblemShape_ const& problem_shape,\n      int iter_index,\n      int iter_count,\n      MainloopArguments const& mainloop_args,\n      MainloopParams const& mainloop_params,\n      TensorStorage& shared_tensors,\n      PipelineMmaReduceDQ& pipeline_mma_reduce_dq,\n      typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,\n      PipelineReduceTmaStore& pipeline_reduce_tma_store,\n      typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {\n\n    using X = Underscore;\n\n    auto [Q, K, D, D_VO, HB] = problem_shape;\n\n    auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;\n\n    // must match TileShapeDQ\n    auto load_op = SM100_TMEM_LOAD_16dp32b16x{};\n\n    auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{});\n    tDQtDQ.data() = TmemAllocation::kDQ;\n\n    Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));\n    auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{})\n        (_, _, _, _0{}, blk_coord_batch);\n\n    Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));\n\n    Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{});\n\n    int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp);\n    auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ);\n    auto thread_t2r = tiled_t2r.get_slice(thread_idx);\n\n    Tensor tTR_cDQ   = thread_t2r.partition_D(cDQ);\n    Tensor tTR_gDQ   = thread_t2r.partition_D(gDQ);\n    Tensor tTR_sDQ   = thread_t2r.partition_D(sDQ);\n    Tensor tTR_tDQ = 
thread_t2r.partition_S(tDQtDQ);\n\n    auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{});\n\n    Tensor tDQsDQ = block_tma.partition_S(sDQ);\n    Tensor tDQcDQ = block_tma.partition_S(cDQ);\n    Tensor tDQgDQ = block_tma.partition_D(gDQ);\n\n    int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;\n\n    while (iter_count > 0) {\n      pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);\n\n      Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));\n\n      // load dQ from tmem to rmem\n      cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ);\n\n      cutlass::arch::fence_view_async_tmem_load();\n      pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state);\n      ++pipeline_mma_reduce_dq_consumer_state;\n\n      // we don't have enough smem to dump it all to smem, so we do it in stages\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < size<2>(tTR_cDQ); i++) {\n        if (lane_predicate) {\n          pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state);\n        }\n        // wait in all threads for the acquire to complete\n        cutlass::arch::NamedBarrier(\n            kNumReduceWarps * NumThreadsPerWarp,\n            cutlass::arch::ReservedNamedBarriers::TransposeBarrier\n        ).arrive_and_wait();\n\n        cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index()));\n\n        // wait for the stores to all be visible to the TMA\n        cutlass::arch::fence_view_async_shared();\n        cutlass::arch::NamedBarrier(\n            kNumReduceWarps * NumThreadsPerWarp,\n            cutlass::arch::ReservedNamedBarriers::TransposeBarrier\n        ).arrive_and_wait();\n        if (lane_predicate) {\n          // launch tma store\n          copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));\n          
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);\n        }\n\n        ++pipeline_reduce_tma_store_producer_state;\n      }\n\n      iter_count -= 1;\n      iter_index += 1;\n    }\n  }\n\n\n  CUTLASS_DEVICE void operator()(Params const& params, char* smem) {\n#if defined(KERUTILS_ENABLE_SM100A)\n    int warp_idx = cutlass::canonical_warp_idx_sync();\n    auto role = warp_idx_to_role(warp_idx);\n    uint32_t lane_predicate = cute::elect_one_sync();\n\n    if (role == WarpRole::Load && lane_predicate) {\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor());\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor());\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor());\n      prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor());\n    }\n\n    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);\n\n    int initializing_warp = 0;\n    typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Mma) {\n      pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer;\n    }\n    pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load);\n    // Also loads K in the first iteration\n    pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ;\n    pipeline_load_mma_q_params.initializing_warp = initializing_warp++;\n    PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_mma_do_params.role = 
PipelineLoadMmaDO::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Mma) {\n      pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer;\n    }\n    pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load);\n    // Also loads V in the first iteration\n    pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO;\n    pipeline_load_mma_do_params.initializing_warp = initializing_warp++;\n    PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer;\n    }\n    pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp;\n    pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;\n    pipeline_load_compute_lse_params.initializing_warp = initializing_warp++;\n    PipelineLoadComputeLSE pipeline_load_compute_lse(\n      shared_storage.pipelines.load_compute_lse,\n      pipeline_load_compute_lse_params,\n      /*barrier init*/ cute::true_type{});\n\n    typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer;\n    }\n    pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp;\n    
pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;\n    pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++;\n    PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo(\n      shared_storage.pipelines.load_compute_sum_odo,\n      pipeline_load_compute_sum_odo_params,\n      /*barrier init*/ cute::true_type{});\n\n    typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params;\n    if (role == WarpRole::Mma) {\n      pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer;\n    }\n    pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_compute_s_params.initializing_warp = initializing_warp++;\n    PipelineMmaComputeS pipeline_mma_compute_s(\n      shared_storage.pipelines.mma_compute_s,\n      pipeline_mma_compute_s_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params;\n    if (role == WarpRole::Mma) {\n      pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer;\n    }\n    pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++;\n    PipelineMmaComputeDP pipeline_mma_compute_dp(\n      shared_storage.pipelines.mma_compute_dp,\n      pipeline_mma_compute_dp_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params;\n    if (role == 
WarpRole::Mma) {\n      pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Reduce) {\n      pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer;\n    }\n    pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++;\n    PipelineMmaReduceDQ pipeline_mma_reduce_dq(\n      shared_storage.pipelines.mma_reduce_dq,\n      pipeline_mma_reduce_dq_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params;\n    if (role == WarpRole::Mma) {\n      pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer;\n    }\n    pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_compute_mma_p_params.consumer_arv_count = 1;\n    pipeline_compute_mma_p_params.initializing_warp = initializing_warp++;\n    PipelineComputeMmaP pipeline_compute_mma_p(\n      shared_storage.pipelines.compute_mma_p,\n      pipeline_compute_mma_p_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params;\n    if (role == WarpRole::Mma) {\n      pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer;\n    }\n    pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_compute_mma_ds_params.consumer_arv_count = 1;\n    
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++;\n    PipelineComputeMmaDS pipeline_compute_mma_ds(\n      shared_storage.pipelines.compute_mma_ds,\n      pipeline_compute_mma_ds_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params;\n    if (role == WarpRole::Mma) {\n      pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Compute) {\n      pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer;\n    }\n    pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;\n    pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++;\n    PipelineMmaComputeDKDV pipeline_mma_compute_dkdv(\n      shared_storage.pipelines.mma_compute_dkdv,\n      pipeline_mma_compute_dkdv_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n    PipelineReduceTmaStore pipeline_reduce_tma_store;\n\n    TmemAllocator tmem_allocator;\n\n    pipeline_init_arrive_relaxed(size(ClusterShape{}));\n\n    pipeline_load_mma_q.init_masks(ClusterShape{});\n    pipeline_load_mma_do.init_masks(ClusterShape{});\n    pipeline_mma_compute_s.init_masks(ClusterShape{});\n    pipeline_mma_compute_dp.init_masks(ClusterShape{});\n    pipeline_mma_reduce_dq.init_masks(ClusterShape{});\n    pipeline_compute_mma_p.init_masks(ClusterShape{});\n    pipeline_compute_mma_ds.init_masks(ClusterShape{});\n    pipeline_mma_compute_dkdv.init_masks(ClusterShape{});\n\n    typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state;\n    typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state;\n    typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state;\n    
typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state;\n    typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state;\n    typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state;\n    typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state;\n    typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;\n    typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;\n    typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;\n\n    auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();\n    auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();\n    auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();\n    auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state<decltype(pipeline_load_compute_sum_odo)>();\n    auto pipeline_mma_compute_s_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_s)>();\n    auto pipeline_mma_compute_dp_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dp)>();\n    auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state<decltype(pipeline_mma_reduce_dq)>();\n    auto pipeline_compute_mma_p_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_p)>();\n    auto pipeline_compute_mma_ds_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_ds)>();\n    auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dkdv)>();\n    auto pipeline_reduce_tma_store_producer_state = 
make_producer_start_state<decltype(pipeline_reduce_tma_store)>();\n\n    pipeline_init_wait(size(ClusterShape{}));\n\n    auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));\n    auto [problem_shape, blk_offset] = apply_variable_length_offset(\n        params.problem_shape,\n        blk_coord\n    );\n    int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});\n    int iter_start = 0;\n    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {\n      iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};\n    } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {\n      int offset = get<1>(problem_shape) - get<0>(problem_shape);\n      iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});\n    }\n    if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {\n      return;\n    }\n    iter_count -= iter_start;\n\n    if (iter_count <= 0) {\n      epilogue_clear(\n          blk_coord,\n          blk_offset,\n          problem_shape,\n          params.mainloop,\n          params.epilogue\n      );\n      return;\n    }\n\n    if (role == WarpRole::Load) {\n      warpgroup_reg_set<RegisterAllocation::kLoad>();\n\n      load(\n          blk_coord,\n          blk_offset,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          params.mainloop_params,\n          shared_storage.tensors,\n          pipeline_load_mma_q, pipeline_load_mma_q_producer_state,\n          pipeline_load_mma_do, pipeline_load_mma_do_producer_state,\n          pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,\n          pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state\n      );\n\n    }\n    else if (role == WarpRole::Mma) {\n      warpgroup_reg_set<RegisterAllocation::kMma>();\n\n      
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);\n      __syncwarp();\n\n      mma(\n          blk_coord,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          shared_storage.tensors,\n          pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,\n          pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,\n          pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,\n          pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,\n          pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state,\n          pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state,\n          pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state,\n          pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state\n      );\n\n    }\n    else if (role == WarpRole::Compute) {\n      warpgroup_reg_set<RegisterAllocation::kCompute>();\n\n      compute(\n          blk_coord,\n          blk_offset,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          params.epilogue,\n          shared_storage.tensors,\n          pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state,\n          pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state,\n          pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state,\n          pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state,\n          pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state,\n          pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state,\n          pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state\n      );\n\n      cutlass::arch::NamedBarrier(\n          kNumComputeWarps * NumThreadsPerWarp,\n          cutlass::arch::ReservedNamedBarriers::EpilogueBarrier\n      ).arrive_and_wait();\n\n      if (warp_idx % 
kNumComputeWarps == 0) {\n        // Every compute warp has passed the barrier above, so a single warp can now\n        // release the TMEM columns that the Mma warp allocated.\n        uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;\n        tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);\n      }\n\n    }\n    else if (role == WarpRole::Reduce) {\n      warpgroup_reg_set<RegisterAllocation::kReduce>();\n\n      reduce(\n          blk_coord,\n          problem_shape,\n          iter_start,\n          iter_count,\n          params.mainloop,\n          params.mainloop_params,\n          shared_storage.tensors,\n          pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,\n          pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state\n      );\n\n      pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);\n    }\n    else {\n      warpgroup_reg_set<RegisterAllocation::kEmpty>();\n\n      /* no-op, donate regs and exit */\n\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\\n\");\n    }\n#endif\n  }\n\n  static dim3 get_block_shape() {\n    dim3 block(MaxThreadsPerBlock, 1, 1);\n    return block;\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    auto [Q, K, D, D_VO, HB] = params.problem_shape;\n    auto [H, B] = HB;\n    dim3 grid(ceil_div(K, TileShapeK{}), H, B);\n    return grid;\n  }\n};\n\n}  // namespace cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cute/layout.hpp\"\n#include \"cutlass/arch/arch.h\"\n#include \"cutlass/kernel_hardware_info.h\"\n#include \"cutlass/pipeline/pipeline.hpp\"\n#include \"cute/arch/tmem_allocator_sm100.hpp\"\n\n#include <kerutils/kerutils.cuh> // for  KERUTILS_ENABLE_SM100A\n#include \"../kernel/fmha_options.hpp\"\n#include \"../kernel/fmha_tile_scheduler.hpp\"\n#include \"../kernel/fmha_causal_tile_scheduler.hpp\"\n#include \"../collective/fmha_fusion.hpp\"\n#include \"../collective/fmha_common.hpp\"\n\nnamespace cutlass::fmha::kernel {\n\nusing namespace cute;\nusing namespace cutlass::fmha::collective;\n\nstruct Sm100FmhaCtxKernelWarpspecializedSchedule {\n\n  enum class WarpRole {\n    Softmax0,\n    Softmax1,\n    Correction,\n    MMA,\n    Load,\n    Epilogue,\n    Empty\n  };\n\n  static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {\n    int wg_idx = warp_idx / 4;                        // warp_idx\n    if (wg_idx == 0) return WarpRole::Softmax0;       //   0 -  3\n    if (wg_idx == 1) return WarpRole::Softmax1;       //   4 -  7\n    if (wg_idx == 2) return WarpRole::Correction;     //   8 - 11\n    if (warp_idx == 12) return WarpRole::MMA;         //       12\n    if (warp_idx == 13) return WarpRole::Load;        //       13\n    if (warp_idx == 14) return 
WarpRole::Epilogue;    //       14\n    return WarpRole::Empty;                           //       15\n  }\n\n  static const int NumWarpsSoftmax = 4;\n  static const int NumWarpsCorrection = 4;\n  static const int NumWarpsEpilogue = 1;\n  static const int NumWarpsLoad = 1;\n\n  static const bool kDebugUsingPrintf = false;\n  static const int NumRegsSoftmax = 192;\n  static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);\n  static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);\n  static const int NumRegsEmpty = 24;\n  \n  static const int NumWarps = 16;\n  \n};\n\n\nstruct Sm100MlaFwdCtxKernelWarpspecializedSchedule {\n\n  enum class WarpRole {\n    Softmax0,\n    Softmax1,\n    Correction,\n    MMA,\n    Load,\n    Epilogue,\n    Empty\n  };\n\n  static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {\n    int wg_idx = warp_idx / 4;                        // warp_idx\n    if (wg_idx == 0) return WarpRole::Softmax0;       //   0 -  3\n    if (wg_idx == 1) return WarpRole::Softmax1;       //   4 -  7\n    if (wg_idx == 2) return WarpRole::Correction;     //   8 - 11\n    if (warp_idx == 12) return WarpRole::MMA;         //       12\n    if (warp_idx == 13) return WarpRole::Load;        //       13\n    if (warp_idx == 14) return WarpRole::Epilogue;    //       14\n    return WarpRole::Empty;                           //       15\n  }\n\n  static const int NumWarpsSoftmax = 4;\n  static const int NumWarpsCorrection = 4;\n  static const int NumWarpsEpilogue = 1;\n  static const int NumWarpsLoad = 1;\n\n  static const bool kDebugUsingPrintf = false;\n  static const int NumRegsSoftmax = 184;\n  static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);\n  static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 
16 : 0);\n  static const int NumRegsEmpty = 24;\n\n  static const int NumWarps = 16;\n\n};\n\ntemplate<\n  class ProblemShapeIn,\n  class CollectiveMainloop,\n  class CollectiveEpilogue,\n  class TileScheduler,\n  class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule\n>\nstruct Sm100FmhaFwdKernelTmaWarpspecialized {\n\n  using TileShape = typename CollectiveMainloop::TileShape;\n  using ProblemShape = ProblemShapeIn;\n\n  using WarpRole = typename KernelSchedule::WarpRole;\n\n  constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {\n    return KernelSchedule::warp_idx_to_WarpRole(warp_idx);\n  }\n\n  static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;\n  static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;\n  static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;\n  static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;\n  \n  static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue);\n  static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad);\n\n  static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;\n  static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;\n  static const int NumRegsOther = KernelSchedule::NumRegsOther;\n  static const int NumRegsEmpty = 24;\n\n  static const int NumWarps = KernelSchedule::NumWarps;\n\n  static constexpr bool IsMla = std::is_same_v<KernelSchedule, Sm100MlaFwdCtxKernelWarpspecializedSchedule>;\n\n  using ClusterShape = typename CollectiveMainloop::ClusterShape;\n\n  using TmemAllocator = cute::TMEM::Allocator1Sm;\n\n  struct SharedStorage {\n    using UnionType = union {\n      typename CollectiveMainloop::TensorStorage mainloop;\n      typename CollectiveEpilogue::TensorStorage epilogue;\n    };\n\n    using  StructType = struct {\n      typename CollectiveMainloop::TensorStorage mainloop;\n      typename CollectiveEpilogue::TensorStorage epilogue;\n    };\n\n    static constexpr bool IsPersistent = 
std::is_same_v<TileScheduler, PersistentTileScheduler> || std::is_same_v<TileScheduler, CausalPersistentTileScheduler>;\n    using MainloopEpilogueStorage = std::conditional_t<IsPersistent, \n                                                       std::conditional_t<IsMla, \n                                                                          std::conditional_t<CollectiveMainloop::IsOrderLoadEpilogue, UnionType, StructType>,\n                                                                          StructType>,\n                                                       UnionType>;\n\n    MainloopEpilogueStorage mainloop_epilogue; \n\n    struct PipelineStorage {\n      alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;\n      alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;\n      alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;\n      alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;\n      alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;\n      alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;\n      alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;\n      alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;\n      alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;\n    } pipelines;\n\n    uint32_t tmem_base_ptr;\n  };\n\n  static constexpr int SharedStorageSize = sizeof(SharedStorage);\n\n  struct Arguments {\n    ProblemShape problem_shape;\n    typename CollectiveMainloop::Arguments mainloop;\n    typename CollectiveEpilogue::Arguments epilogue;\n    cutlass::KernelHardwareInfo hw_info;\n  };\n\n  struct Params {\n    ProblemShape problem_shape;\n    typename CollectiveMainloop::Params mainloop;\n    typename CollectiveEpilogue::Params epilogue;\n    typename TileScheduler::Params tile_scheduler;\n  };\n\n  static const int 
MinBlocksPerMultiprocessor = 1;\n  static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;\n  using ArchTag = cutlass::arch::Sm100;\n\n  static size_t get_workspace_size(Arguments const& args) { return 0; }\n  static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {\n    return cutlass::Status::kSuccess;\n  }\n\n  static bool can_implement(Arguments const& args) {\n    return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);\n  }\n\n  static dim3 get_grid_shape(Params const& params) {\n    return TileScheduler::get_grid_shape(params.tile_scheduler);\n  }\n\n  static dim3 get_block_shape() {\n    dim3 block(MaxThreadsPerBlock, 1, 1);\n    return block;\n  }\n\n  static Params to_underlying_arguments(Arguments const& args, void* workspace) {\n    return Params{\n        args.problem_shape,\n        CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),\n        CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),\n        TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{})\n    };\n  }\n\n  CUTLASS_DEVICE auto apply_batch(const Params &params, ProblemShape const& problem_shape, int batch_idx) {\n    return apply_variable_length(params.problem_shape, batch_idx);\n  }\n\n  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {\n#if defined(KERUTILS_ENABLE_SM100A)\n\n    TileScheduler tile_scheduler{params.tile_scheduler};\n\n    int warp_idx = cutlass::canonical_warp_idx_sync();\n    auto role = warp_idx_to_WarpRole(warp_idx);\n    uint32_t lane_predicate = cute::elect_one_sync();\n\n    if (role == WarpRole::Load && lane_predicate) {\n      CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);\n    }\n\n    if (role == WarpRole::Epilogue && lane_predicate) {\n      CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);\n    }\n\n    SharedStorage& 
shared_storage = *reinterpret_cast<SharedStorage*>(smem);\n\n    auto get_epilogue_storage = [&]() {\n      if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {\n        return reinterpret_cast<typename CollectiveEpilogue::TensorStorage *>(shared_storage.mainloop_epilogue.mainloop.smem_o.data());\n      } else {\n        return &shared_storage.mainloop_epilogue.epilogue;\n      }\n    };\n    typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage();\n\n\n    typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::MMA) {\n      pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;\n    }\n    pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load);\n    pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ;\n    typename CollectiveMainloop::PipelineQ pipeline_load_q(\n      shared_storage.pipelines.load_q,\n      pipeline_load_q_params,\n      ClusterShape{},  cute::true_type{}, /*mask calc*/cute::false_type{});\n    \n    typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;\n    if (role == WarpRole::Load) {\n      pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::MMA) {\n      pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;\n    }\n    pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);\n    pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK;\n    typename CollectiveMainloop::PipelineKV pipeline_load_kv(\n      shared_storage.pipelines.load_kv,\n      pipeline_load_kv_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask 
calc*/cute::false_type{});\n\n    typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;\n    if (role == WarpRole::MMA) {\n      pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Softmax0) {\n      pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;\n    }\n    pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;\n    typename CollectiveMainloop::PipelineS pipeline_mma_s0(\n      shared_storage.pipelines.mma_s0,\n      pipeline_mma_s0_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;\n    if (role == WarpRole::MMA) {\n      pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Softmax1) {\n      pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;\n    }\n    pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;\n    typename CollectiveMainloop::PipelineS pipeline_mma_s1(\n      shared_storage.pipelines.mma_s1,\n      pipeline_mma_s1_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;\n    if (role == WarpRole::Softmax0) {\n      pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Correction) {\n      pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;\n    }\n    pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;\n    pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;\n    typename CollectiveMainloop::PipelineC pipeline_s0_corr(\n      
shared_storage.pipelines.s0_corr,\n      pipeline_s0_corr_params,\n      /*barrier init*/ cute::true_type{});\n\n    typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;\n    if (role == WarpRole::Softmax1) {\n      pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Correction) {\n      pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;\n    }\n    pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;\n    pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;\n    typename CollectiveMainloop::PipelineC pipeline_s1_corr(\n      shared_storage.pipelines.s1_corr,\n      pipeline_s1_corr_params,\n      /*barrier init*/ cute::true_type{});\n\n    typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;\n    if (role == WarpRole::MMA) {\n      pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Correction) {\n      pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;\n    }\n    pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;\n    typename CollectiveMainloop::PipelineO pipeline_mma_corr(\n      shared_storage.pipelines.mma_corr,\n      pipeline_mma_corr_params,\n      ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});\n\n    typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;\n    if (role == WarpRole::Correction) {\n      pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;\n    }\n    if (role == WarpRole::Epilogue) {\n      pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;\n    }\n    pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * 
cutlass::NumThreadsPerWarp;\n    pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;\n    typename CollectiveMainloop::PipelineE pipeline_corr_epi(\n      shared_storage.pipelines.corr_epi,\n      pipeline_corr_epi_params,\n      /*barrier init*/ cute::true_type{});\n\n    typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;\n    params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;\n    params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;\n    typename CollectiveMainloop::OrderBarrierSoftmax order_s01(\n      shared_storage.pipelines.order_s01, params_order_s01);\n\n    TmemAllocator tmem_allocator;\n\n    __syncthreads();\n\n    pipeline_load_q.init_masks(ClusterShape{});\n    pipeline_load_kv.init_masks(ClusterShape{});\n    pipeline_mma_s0.init_masks(ClusterShape{});\n    pipeline_mma_s1.init_masks(ClusterShape{});\n    pipeline_mma_corr.init_masks(ClusterShape{});\n\n    typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;\n    typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineQ>();\n\n    typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;\n    typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineKV>();\n\n    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;\n    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();\n\n    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;\n    typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = 
cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();\n\n    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;\n    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();\n\n    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;\n    typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();\n\n    typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;\n    typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineE>();\n\n    typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;\n    typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();\n\n    CollectiveMainloop mainloop;\n    CollectiveEpilogue epilogue{params.epilogue};\n\n    if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {\n      warpgroup_reg_set<NumRegsSoftmax>();\n\n      CUTLASS_PRAGMA_NO_UNROLL\n      for (; tile_scheduler.is_valid(); ++tile_scheduler) {\n        auto blk_coord = tile_scheduler.get_block_coord();\n\n        auto logical_problem_shape = apply_batch(params,\n            params.problem_shape, get<2,1>(blk_coord));\n\n        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {\n          continue;\n        }\n\n        if (get<1>(logical_problem_shape) == 0) {\n          continue;\n        }\n\n        bool is_softmax_0 = role == WarpRole::Softmax0;\n\n        mainloop.softmax(\n           is_softmax_0 ? 
0 : 1, blk_coord,\n           params.mainloop, logical_problem_shape,\n           is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,\n           is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,\n           is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,\n           is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,\n           order_s01\n         );\n\n       }\n    }\n    else if (role == WarpRole::Correction) {\n      cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();\n\n      bool has_valid = false;\n\n      CUTLASS_PRAGMA_NO_UNROLL\n      for (; tile_scheduler.is_valid(); ++tile_scheduler) {\n        auto blk_coord = tile_scheduler.get_block_coord();\n\n        auto logical_problem_shape = apply_batch(params,\n            params.problem_shape, get<2,1>(blk_coord));\n\n        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {\n          continue;\n        }\n\n        has_valid = true;\n\n        if (get<1>(logical_problem_shape) == 0) {\n          mainloop.correction_empty(\n            blk_coord,\n            params.mainloop, logical_problem_shape,\n            params.problem_shape,\n            epilogue_storage,\n            pipeline_corr_epi, pipeline_corr_epi_producer_state,\n            epilogue\n          );\n          continue;\n        }\n\n        mainloop.correction(\n          blk_coord,\n          params.mainloop, logical_problem_shape,\n          params.problem_shape,\n          epilogue_storage,\n          pipeline_s0_corr, pipeline_s0_corr_consumer_state,\n          pipeline_s1_corr, pipeline_s1_corr_consumer_state,\n          pipeline_mma_corr, pipeline_mma_corr_consumer_state,\n          pipeline_corr_epi, pipeline_corr_epi_producer_state,\n          epilogue\n        );\n\n      }\n\n      if constexpr (NumWarpsEpilogue == 0) {\n        static_assert(NumWarpsCorrection == 1);\n\n        if (has_valid) {\n          uint32_t free_stage_ptr = 
shared_storage.tmem_base_ptr;\n          tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);\n        }\n      }\n\n    }\n    else if (role == WarpRole::MMA) {\n      warpgroup_reg_set<NumRegsOther>();\n\n      bool allocated = false;\n\n      CUTLASS_PRAGMA_NO_UNROLL\n      for (; tile_scheduler.is_valid(); ++tile_scheduler) {\n        auto blk_coord = tile_scheduler.get_block_coord();\n\n        auto logical_problem_shape = apply_batch(params,\n            params.problem_shape, get<2,1>(blk_coord));\n\n        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {\n          continue;\n        }\n\n        if (!allocated) {\n          tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);\n          __syncwarp();\n          allocated = true;\n        }\n\n        if (get<1>(logical_problem_shape) == 0) {\n          continue;\n        }\n\n        mainloop.mma(\n          blk_coord,\n          params.mainloop, logical_problem_shape,\n          shared_storage.mainloop_epilogue.mainloop,\n          pipeline_load_q, pipeline_load_q_consumer_state,\n          pipeline_load_kv, pipeline_load_kv_consumer_state,\n          pipeline_mma_s0, pipeline_mma_s0_producer_state,\n          pipeline_mma_s1, pipeline_mma_s1_producer_state,\n          pipeline_mma_corr, pipeline_mma_corr_producer_state\n        );\n\n      }\n    }\n    else if (role == WarpRole::Load) {\n      warpgroup_reg_set<NumRegsOther>();\n\n      if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {\n        cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, \n                                      cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n      }\n\n      CUTLASS_PRAGMA_NO_UNROLL\n      for (; tile_scheduler.is_valid(); ++tile_scheduler) {\n        auto blk_coord = tile_scheduler.get_block_coord();\n\n        auto logical_problem_shape = 
apply_batch(params,\n            params.problem_shape, get<2,1>(blk_coord));\n\n        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {\n          continue;\n        }\n\n        if (get<1>(logical_problem_shape) == 0) {\n          continue;\n        }\n\n        mainloop.load(\n          blk_coord, logical_problem_shape,\n          params.mainloop, params.problem_shape,\n          shared_storage.mainloop_epilogue.mainloop,\n          pipeline_load_q, pipeline_load_q_producer_state,\n          pipeline_load_kv, pipeline_load_kv_producer_state\n        );\n\n      }\n    }\n    else if (role == WarpRole::Epilogue) {\n      warpgroup_reg_set<NumRegsOther>();\n\n      bool has_valid = false;\n\n      CUTLASS_PRAGMA_NO_UNROLL\n      for (; tile_scheduler.is_valid(); ++tile_scheduler) {\n        auto blk_coord = tile_scheduler.get_block_coord();\n\n        auto logical_problem_shape = apply_batch(params,\n            params.problem_shape, get<2,1>(blk_coord));\n\n        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {\n          continue;\n        }\n\n        has_valid = true;\n\n        epilogue.store(\n          blk_coord, logical_problem_shape,\n          params.epilogue, params.problem_shape,\n          epilogue_storage,\n          pipeline_corr_epi, pipeline_corr_epi_consumer_state\n        );\n\n      }\n\n      static_assert(NumWarpsEpilogue <= 1);\n      if constexpr (NumWarpsEpilogue == 1) {\n        if(has_valid) {\n          uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;\n          tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);\n        }\n      }\n\n    }\n    else if (role == WarpRole::Empty) {\n      warpgroup_reg_set<NumRegsEmpty>();\n\n      /* no-op, donate regs and exit */\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\\n\");\n    }\n#endif\n  }\n\n};\n\n}  // namespace 
cutlass::fmha::kernel\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/common_subroutine.h",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\nnamespace sm100 {\n\n/*\nLoad K/V indices from global memory, and generate validity mask\nEach thread loads 8 indices\nShould be called by lanes 0 ~ (BLOCK_TOPK/8)\n*/\nCUTE_DEVICE\nchar load_indices_and_generate_mask(\n    int lane_idx,\n    int* gIndices,\n    int s_kv,\n    int abs_pos_start,\n    int topk_length\n) {\n    int indices[8];\n    KU_LDG_256(\n        gIndices + lane_idx*8, \n        indices,\n        \".nc\", \n        \"no_allocate\", \n        \"evict_normal\", \n        \"256B\"\n    );\n    auto is_valid = [&](int rel_pos_in_lane, int index) -> char {\n        int abs_pos = abs_pos_start + lane_idx*8 + rel_pos_in_lane;\n        return index >= 0 && index < s_kv && abs_pos < topk_length;\n    };\n    char is_ks_valid_mask = \\\n        is_valid(7, indices[7]) << 7 | \n        is_valid(6, indices[6]) << 6 | \n        is_valid(5, indices[5]) << 5 |\n        is_valid(4, indices[4]) << 4 |\n        is_valid(3, indices[3]) << 3 |\n        is_valid(2, indices[2]) << 2 |\n        is_valid(1, indices[1]) << 1 |\n        is_valid(0, indices[0]) << 0;\n    return is_ks_valid_mask;\n}\n\n\n/*\nGet P from Tensor Memory, reduce P within shared memory, perform masking, and store back if necessary\n\nInitially, since dual gemm is used, we have two P pieces in Tensor Memory, one occupying rows 0 ~ 63 while the other occupying rows 64 ~ 127. 
We'd like to have them reduced into one single P piece, stored in registers with layout:\n\n        N       N    ---   (topk)\n    +-------+-------+\n    |       |       |\n32  | Warp0 | Warp2 |\n    |       |       |\n    +-------+-------+\n    |       |       |\n32  | Warp1 | Warp3 |\n    |       |       |\n    +-------+-------+\n|\n(head)\n\nwhere N = NUM_ELEMS_PER_THREAD\n*/\ntemplate<\n    int NUM_ELEMS_PER_THREAD,\n    int TMEM_COL_START,\n    int BARRIER_WARP02_SYNC_ID,\n    int BARRIER_WARP13_SYNC_ID,\n    bool STORE_BACK_P\n>\nCUTE_DEVICE\nvoid retrieve_mask_and_reduce_p(\n    char* k_validness_base,\n    int local_warp_idx,\n    int lane_idx,\n    auto slot_bar_P_empty_arrival,\n    float p_exchange_buf[4][32*NUM_ELEMS_PER_THREAD],\n    float p[NUM_ELEMS_PER_THREAD]\n) {\n    using namespace cute;\n    using cutlass::arch::NamedBarrier;\n    static_assert(BARRIER_WARP13_SYNC_ID == BARRIER_WARP02_SYNC_ID+1);\n\n    float p_peer[NUM_ELEMS_PER_THREAD];\n    if (local_warp_idx < 2) {\n        ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START, p);\n        ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p_peer);\n    } else {\n        ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START, p_peer);\n        ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p);\n    }\n    cutlass::arch::fence_view_async_tmem_load();\n    ku::tcgen05_before_thread_sync();\n    slot_bar_P_empty_arrival();\n\n    // Mask invalid tokens\n    // We put masking before reduction, since (-inf) + anything (except nan and +inf) is (-inf), which guarantees correctness, and this can overlap with smem load\n    static_assert(NUM_ELEMS_PER_THREAD == 32);\n    uint32_t is_k_valid = *(uint32_t*)(k_validness_base + (local_warp_idx>=2?NUM_ELEMS_PER_THREAD/8:0));\n    CUTE_UNROLL\n    for (int i = 0; i < NUM_ELEMS_PER_THREAD; i += 1) {\n        if (!(is_k_valid >> i & 1))\n            p[i] = -CUDART_INF_F;\n    
}\n\n    // Reduce P within the cluster\n    {\n        // Store\n        // Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (col 0 ~ 31) part\n        CUTE_UNROLL\n        for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {\n            ku::st_shared(&p_exchange_buf[local_warp_idx^2][i*32*4 + lane_idx*4], *(float4*)(p_peer + i*4));\n        }\n        NamedBarrier::arrive_and_wait(64, BARRIER_WARP02_SYNC_ID + (local_warp_idx&1));\n        CUTE_UNROLL\n        for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {\n            float2 t[2];\n            *(float4*)t = *(float4*)(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4]);\n            float2* cur_p = (float2*)(p + i*4);\n            cur_p[0] = ku::float2_add(cur_p[0], t[0]);\n            cur_p[1] = ku::float2_add(cur_p[1], t[1]);\n        }\n    }\n\n    if constexpr (STORE_BACK_P) {\n        CUTE_UNROLL\n        for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {\n            ku::st_shared(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4], *(float4*)(p+i*4));\n        }\n    }\n}\n\n/*\nRescale O in Tensor Memory.\n\nO should occupy 128 rows x (D_V/2) columns in Tensor Memory.\n*/\ntemplate<\n    int D_V,\n    int CHUNK_SIZE,\n    int TMEM_COL_START\n>\nCUTE_DEVICE\nvoid rescale_O(\n    float scale_factor\n) {\n    float2 scale_factor_float2 = {scale_factor, scale_factor};\n    float2 o[CHUNK_SIZE/2];\n\n    CUTE_UNROLL\n    for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {\n        // Load O\n        ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o);\n        cutlass::arch::fence_view_async_tmem_load();\n\n        // Mult\n        for (int i = 0; i < CHUNK_SIZE/2; ++i) {\n            o[i] = ku::float2_mul(o[i], scale_factor_float2);\n        }\n\n        // Store O\n        ku::tmem_st_32dp32bNx<CHUNK_SIZE>(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o);\n        cutlass::arch::fence_view_async_tmem_store();\n    }\n}\n\ntemplate<int 
NUM_ELEMS_PER_THREAD>\nCUTE_DEVICE\nfloat get_max(\n    float p[NUM_ELEMS_PER_THREAD]\n) {\n    float local_max = -CUDART_INF_F;\n    CUTE_UNROLL\n    for (int i = 0; i < NUM_ELEMS_PER_THREAD; ++i) {\n        local_max = max(local_max, p[i]);\n    }\n    return local_max;\n}\n\n/*\nCalculate s := exp2f(p*scale - new_max) and its sum\n*/\ntemplate<int NUM_ELEMS_PER_THREAD>\nCUTE_DEVICE\nfloat get_s_from_p(\n    nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2],\n    float p[NUM_ELEMS_PER_THREAD],\n    float scale,\n    float new_max\n) {\n    float2 cur_sum = float2 {0.0f, 0.0f};\n    float2 neg_new_max_float2 = float2 {-new_max, -new_max};\n    float2 scale_float2 = float2 {scale, scale};\n    CUTE_UNROLL\n    for (int i = 0; i < NUM_ELEMS_PER_THREAD/2; i += 1) {\n        float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale_float2, neg_new_max_float2);\n        d.x = exp2f(d.x);\n        d.y = exp2f(d.y);\n        cur_sum = ku::float2_add(cur_sum, d);\n        s[i] = __float22bfloat162_rn(d);\n    }\n    return cur_sum.x + cur_sum.y;\n}\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head128/config.h",
    "content": "#pragma once\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\n#include \"params.h\"\n#include \"defines.h\"\n\nnamespace sm100::fwd::head128 {\n\nusing namespace cute;\n\ntemplate<\n    typename Shape_Q, typename TMA_Q,\n    typename Shape_O, typename TMA_O\n>\nstruct TmaParams {\n    Shape_Q shape_Q; TMA_Q tma_Q;\n    Shape_O shape_O; TMA_O tma_O;\n    CUtensorMap tensor_map_kv;\n};\n\nstruct float2x2 {\n    float2 lo, hi;\n};\n\ntemplate<int D_QK>\nstruct KernelTemplate {\n\nstatic constexpr int D_Q = D_QK;\nstatic constexpr int D_K = D_QK;\nstatic constexpr int D_V = 512;\nstatic constexpr float MAX_INIT_VAL = -1e30;    // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan\n\nstatic constexpr int B_H = 128;    // For 2 CTAs\nstatic constexpr int B_TOPK = 128; // For 2 CTAs\nstatic constexpr int NUM_BUFS = 2;\nstatic constexpr int NUM_THREADS = 256 + 128 + 128; // 128 scale & exp threads, 128x2 TMA threads, 32 UTCMMA threads\n\n\nstatic constexpr int D_tQ = 384, NUM_tQ_TILES = D_tQ / 64;\nstatic constexpr int D_sQ = D_QK-D_tQ, NUM_sQ_TILES = D_sQ / 64;\nstatic_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q);\n\n// Tensor memory columns\nstruct tmem_cols {\n    //   0 ~ 256: output\n    // 256 ~ 320: P\n    // 320 ~ 512: Q[D_QK-D_tQ:]\n    static constexpr int o = 0;\n    static constexpr int p = 256;\n    static constexpr int q = 512 - D_tQ/2;\n    static_assert(p+64 <= q);\n};\n\ntemplate<int NUM_TILES>\nusing SmemLayoutQTiles = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutOTiles = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutO = 
SmemLayoutOTiles<8>;\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_TOPK/2>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutV = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_MN_SW128_Atom<bf16>{},\n    Shape<Int<256>, Int<B_TOPK>>{},\n    Step<_2, _1>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutSTiles = decltype(coalesce(tile_to_shape(\n\tUMMA::Layout_K_INTER_Atom<bf16>{},\n\tShape<Int<B_H/2>, Int<64*NUM_TILES>>{},\n\tStep<_1, _2>{}\n), Shape<_1, _1>{}));\n\nstruct SharedMemoryPlan {\n    union {\n        array_aligned<bf16, cosize_v<SmemLayoutQTiles<D_Q/64>>> q_full;\n        struct {\n            array_aligned<bf16, cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>>> sq;\n            array_aligned<bf16, cosize_v<SmemLayoutV>> v;\n            // NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q\n            static_assert(cosize_v<SmemLayoutQTiles<D_Q/64>> <= cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>> + cosize_v<SmemLayoutV>);\n            array_aligned<bf16, cosize_v<SmemLayoutKTiles<D_K/64>>> k;\n        } s;\n        array_aligned<bf16, cosize_v<SmemLayoutO>> o;\n    } u;\n    array_aligned<bf16, cosize_v<SmemLayoutSTiles<2>>> s;\n    float p[(B_H/2)*B_TOPK];\n    char is_k_valid[NUM_BUFS][B_TOPK/8];\n    transac_bar_t bar_prologue_q, bar_prologue_utccp;\n    transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS];    // Pi = QKi^T done (i.e. Ki free)\n    transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS];    // O += SiVi done (i.e. 
Vi free)\n    transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS];\n    transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS];    // Vi is ready\n    transac_bar_t bar_p_free[NUM_BUFS];\n    transac_bar_t bar_so_ready[NUM_BUFS];   // S and O are ready\n    transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];\n    array_aligned<uint32_t, 1> tmem_start_addr;\n    float rowwise_max_buf[128], rowwise_li_buf[128];\n};\n\nusing TiledMMA_P_tQ = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}\n));\n\nusing TiledMMA_P_sQ = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}\n));\n\nusing TiledMMA_O = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{},\n    Layout<Shape<_1, _1, _1>>{},\n    Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{}  // We use this permutation layout so that CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]\n));\n\ntemplate<typename TmaParams>\nstatic __device__ void\nsparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params);\n\n};\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head128 {\n\ntemplate void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head128 {\n\ntemplate void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh",
    "content": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cluster_launch.hpp>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cutlass/arch/arch.h>\n#include <cutlass/cuda_host_adapter.hpp>\n\n#include \"params.h\"\n#include \"utils.h\"\n#include \"sm100/helpers.h\"\n\n#include \"config.h\"\n\nnamespace sm100::fwd::head128 {\n\nusing namespace cute;\n\nCUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) {\n    int32x8_t val;\n    asm volatile(\"ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\"\n        : \"=r\"(val.a0), \"=r\"(val.a1), \"=r\"(val.a2), \"=r\"(val.a3),\n          \"=r\"(val.a4), \"=r\"(val.a5), \"=r\"(val.a6), \"=r\"(val.a7)\n        : \"l\"(src_ptr)\n    );\n    return val;\n}\n\n/*\nPipeline Overview:\n\n| Copy |    MMA    |   Scale & Exp   |\n\nK0\nV0\n        P0 = QK0^T\nK1                  S0 = exp(P0)\n                    scale(O) w.r.t P0\n        P1 = QK1^T\nK2                  S1 = exp(P1)\n        O += S0V0\nV1                  scale(O) w.r.t P1\n        P2 = QK2^T\nK3                  S2 = exp(P2)\n        O += S1V1\nV2                  scale(O) w.r.t P2\n        P3 = QK3^T\nK4                  S3 = exp(P3)\n        O += S2V2\nV3                  scale(O) w.r.t P3\n\n...\n\n        O += S(n-3)V(n-3)\nV(n-2)              scale(O) w.r.t P(n-2)\n        P(n-1) = QK(n-1)^T\n                   S(n-1) = exp(P(n-1))\n        O += S(n-2)V(n-2)\nV(n-1)             scale(O) w.r.t P(n-1)\n        O += S(n-1)V(n-1)\n*/\n\ntemplate<int D_QK>\ntemplate<typename TmaParams>\n__device__ void\nKernelTemplate<D_QK>::sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params) {\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))\n    const int cta_idx = blockIdx.x % 2;\n    const int s_q_idx = blockIdx.x / 2;\n    
const int warp_idx = cutlass::canonical_warp_idx_sync();\n    const int lane_idx = threadIdx.x % 32;\n    const int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + s_q_idx) : params.topk;\n    const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1);  // num_k_blocks always >= 1\n    const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);\n    const int idx_in_warpgroup = threadIdx.x % 128;\n\n    // Prefetch TMA descriptors\n    if (threadIdx.x == 0) {\n        cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv));\n    }\n\n    // Define shared tensors\n    extern __shared__ char wksp_buf[];\n    SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);\n    Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<D_Q/64>{});\n\n    int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]\n\n    // Allocate tmem tensors\n    TiledMMA tiled_mma_P_tQ = TiledMMA_P_tQ{};\n    TiledMMA tiled_mma_P_sQ = TiledMMA_P_sQ{};\n    TiledMMA tiled_mma_O = TiledMMA_O{};\n    Tensor tP = partition_fragment_C(tiled_mma_P_tQ, Shape<Int<B_H/2>, Int<B_TOPK>>{});\n    Tensor tQr = tiled_mma_P_tQ.get_slice(_0{}).make_fragment_A(\n        partition_shape_A(tiled_mma_P_tQ, Shape<Int<B_H/2>, Int<D_tQ>>{})\n    );\n    Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H/2>, Int<D_V>>{});\n    tP.data().get() = tmem_cols::p;\n    tQr.data().get() = tmem_cols::q;\n    tO.data().get() = tmem_cols::o;\n\n    if (warp_idx == 0) {\n        if (elect_one_sync()) {\n            // Initialize barriers\n            plan.bar_prologue_q.init(1);\n            plan.bar_prologue_utccp.init(1);\n            CUTE_UNROLL\n            for (int i = 0; i < NUM_BUFS; ++i) {\n                
plan.bar_qk_part_done[i].init(1);\n                plan.bar_qk_done[i].init(1);\n                plan.bar_sv_part_done[i].init(1);\n                plan.bar_sv_done[i].init(1);\n                plan.bar_k_part0_ready[i].init(1);\n                plan.bar_k_part1_ready[i].init(1);\n                plan.bar_v_part0_ready[i].init(1);\n                plan.bar_v_part1_ready[i].init(1);\n                plan.bar_p_free[i].init(128*2);\n                plan.bar_so_ready[i].init(128*2);\n                plan.bar_k_valid_ready[i].init(16);\n                plan.bar_k_valid_free[i].init(128);\n            }\n            fence_barrier_init();\n        }\n    }\n\n    cute::cluster_sync();   // We must add a cluster_sync() here, or TMA from CTA1 may launch before barrier initialization in CTA0\n\n    if (warp_idx == 0) {\n        if (elect_one_sync()) {\n            // Copy Q\n            Tensor gQ = flat_divide(\n                tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),\n                Tile<Int<B_H/2>>{}\n            )(_, cta_idx, _);\n            ku::launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST);\n        }\n\n        // Initialize TMEM\n        cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data());\n        TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);\n        cute::TMEM::Allocator2Sm().release_allocation_lock();\n    }\n\n    __syncthreads();    // Wait for TMEM allocation\n\n    if (warpgroup_idx == 0) {\n        cutlass::arch::warpgroup_reg_alloc<144>();\n        // Scale & Exp warps\n\n        // The following three numbers are \n        // - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V)\n        // - li: sumexp, i.e. li := sum(exp(Pi*scale - mi))\n        // - real_mi: real max logits, i.e. 
real_mi := max(Pi*scale)\n        // where Pi is the i-th row of P, P := QK^T\n        // mi and real_mi are always consistent within the two threads that\n        // control one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update\n        float mi = MAX_INIT_VAL;\n        float li = 0.0f;\n        float real_mi = -CUDART_INF_F;\n\n        const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};\n        uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8);\n        float* sP_base = plan.p + idx_in_warpgroup%64*4 + (idx_in_warpgroup/64)*((B_H/2)*(B_TOPK/2));\n\n        CUTE_NO_UNROLL\n        for (int k = 0; k < num_k_blocks; ++k) {\n            // Wait for P\n            plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);\n            ku::tcgen05_after_thread_sync();\n\n            // Load P\n            float2 p[(B_TOPK/2)/2];\n            ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::p, p);\n            cutlass::arch::fence_view_async_tmem_load();\n            ku::tcgen05_before_thread_sync();\n            plan.bar_p_free[k%NUM_BUFS].arrive(0u);\n\n            // Mask\n            plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1);\n            // The following code enables NVCC to use the R2P instruction\n            // Although we perform 2x LDS.32 instructions here, don't worry, NVCC will\n            // convert them to one LDS.64 instruction. 
However, if we write LDS.64\n            // here, NVCC won't use R2P.\n            uint32_t is_k_valid_lo = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0));\n            uint32_t is_k_valid_hi = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0) + 4);\n            float* p_float = (float*)p;\n            CUTE_UNROLL\n            for (int i = 0; i < (B_TOPK/2)/2; i += 1) {\n                if (!(is_k_valid_lo >> i & 1))\n                    p_float[i] = -CUDART_INF_F;\n            }\n            CUTE_UNROLL\n            for (int i = 0; i < (B_TOPK/2)/2; i += 1) {\n                if (!(is_k_valid_hi >> i & 1))\n                    p_float[i+(B_TOPK/2)/2] = -CUDART_INF_F;\n            }\n\n            // Get rowwise max of Pi\n            float cur_pi_max = -CUDART_INF_F;\n            CUTE_UNROLL\n            for (int i = 0; i < (B_TOPK/2); i += 1) {\n                cur_pi_max = max(cur_pi_max, p_float[i]);\n            }\n            cur_pi_max *= params.sm_scale_div_log2;\n\n            plan.bar_k_valid_free[k%NUM_BUFS].arrive();\n\n            NamedBarrier::arrive_and_wait(128, 0);  // Wait for rowwise_max_buf and sP to be ready\n            plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;\n            NamedBarrier::arrive_and_wait(128, 0);  // TODO Name these barriers\n            cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);\n            real_mi = max(real_mi, cur_pi_max);\n            bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);\n            // By this point:\n            // - cur_pi_max, real_mi, and mi is identical within each row (i.e. 
thread 0+64, 1+65, ...)\n            // - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127\n\n\n            // Calc scale factor, and scale li\n            float new_max, scale_for_old;\n            if (!should_scale_o) {\n                // Don't scale O\n                scale_for_old = 1.0f;\n                new_max = mi;\n            } else {\n                new_max = max(cur_pi_max, mi);\n                scale_for_old = exp2f(mi - new_max);\n            }\n            mi = new_max;   // mi is still identical within each row\n            li *= scale_for_old;\n\n            // Calculate S\n            __nv_bfloat162 s[(B_TOPK/2)/2];\n            float2 neg_new_max = float2 {-new_max, -new_max};\n            CUTE_UNROLL\n            for (int i = 0; i < (B_TOPK/2)/2; i += 1) {\n                float2 d = ku::float2_fma(p[i], scale, neg_new_max);\n                d.x = exp2f(d.x);\n                d.y = exp2f(d.y);\n                li += d.x + d.y;    // NOTE: Theoretically we could use FFMA2 here but actually this is faster...\n                s[i] = __float22bfloat162_rn(d);\n            }\n\n            // Wait for last SV gemm, write S\n            if (k > 0) {\n                plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);\n            }\n            CUTE_UNROLL\n            for (int i = 0; i < B_TOPK/2/8; i += 1) {\n                sS_base[64*i] = *(uint128_t*)(s + i*4);\n            }\n\n            // Scale O\n            if (k > 0 && should_scale_o) {\n                float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old}; \n                // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);   // NOTE: We have waited for last SV gemm before\n                ku::tcgen05_after_thread_sync();\n\n                static constexpr int CHUNK_SIZE = 32;\n                float2 o[CHUNK_SIZE/2];\n                CUTE_UNROLL\n                for (int chunk_idx = 0; chunk_idx < 
(D_V/2)/CHUNK_SIZE; ++chunk_idx) {\n                    // Load O\n                    ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);\n                    cutlass::arch::fence_view_async_tmem_load();\n\n                    // Mult\n                    for (int i = 0; i < CHUNK_SIZE/2; ++i) {\n                        o[i] = ku::float2_mul(o[i], scale_for_old_float2);\n                    }\n\n                    // Store O\n                    ku::tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);\n                    cutlass::arch::fence_view_async_tmem_store();\n                }\n                ku::tcgen05_before_thread_sync();\n            }\n            \n            fence_view_async_shared();\n            plan.bar_so_ready[k%NUM_BUFS].arrive(0u);\n        }\n\n        // Epilogue\n\n        if (real_mi == -CUDART_INF_F) {\n            // real_mi == -CUDART_INF_F <=> No valid TopK indices\n            // We set li to 0 to fit the definition that li := exp(x[i] - mi)\n            li = 0.0f;\n            mi = -CUDART_INF_F;\n        }\n        \n        // Exchange li\n        plan.rowwise_li_buf[idx_in_warpgroup] = li;\n        NamedBarrier::arrive_and_wait(128, 0);\n        li += plan.rowwise_li_buf[idx_in_warpgroup^64];\n\n        // Store mi and li\n        if (idx_in_warpgroup < 64) {\n            int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup;\n            float cur_lse = logf(li) + mi*CUDART_LN2_F;\n            cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;\n            params.max_logits[global_index] = real_mi*CUDART_LN2_F;\n            params.lse[global_index] = cur_lse;\n        }\n\n        // Wait for the last GEMM\n        plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);\n        ku::tcgen05_after_thread_sync();\n\n        // Store O\n        float attn_sink = params.attn_sink == nullptr ? 
-CUDART_INF_F : __ldg(params.attn_sink + cta_idx*B_H/2 + (idx_in_warpgroup%64))*CUDART_L2E_F;\n        float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi));\n        Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});\n        constexpr int B_EPI = 64;\n        Tensor tma_gO = flat_divide(\n            tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx),\n            Shape<Int<B_H/2>, Int<B_EPI>>{}\n        )(_, _, cta_idx, _);\n        Tensor sO_divided = flat_divide(\n            sO,\n            Shape<Int<B_H/2>, Int<B_EPI>>{}\n        )(_, _, _0{}, _);\n        auto thr_tma = tma_params.tma_O.get_slice(_0{});\n\n        float2 o[B_EPI/2];\n        bool have_valid_indices = __any_sync(0xffffffff, li != 0);  // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld\n        if (!have_valid_indices) {\n            // If there are no valid indices, we set o[i] to 0 and don't load from TMEM\n            CUTE_UNROLL\n            for (int i = 0; i < B_EPI/2; ++i)\n                o[i].x = o[i].y = 0.0f;\n            output_scale = 1.0f;\n        }\n\n        float2 output_scale_float2 = make_float2(output_scale, output_scale);\n\n        CUTE_UNROLL\n        for (int k = 0; k < (D_V/2)/B_EPI; ++k) {\n            // Load O from tO\n            if (have_valid_indices) {\n                ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::o + k*B_EPI, o);\n                cutlass::arch::fence_view_async_tmem_load();\n            }\n\n            // Convert and store\n            CUTE_UNROLL\n            for (int i = 0; i < B_EPI/8; ++i) {\n                __nv_bfloat162 o_bf16[4];\n                CUTE_UNROLL\n                for (int j = 0; j < 4; ++j) {\n                    float2 d = ku::float2_mul(o[i*4+j], output_scale_float2);\n                    o_bf16[j] = __float22bfloat162_rn(d);\n                }\n                int smem_row = idx_in_warpgroup % 64;\n                int smem_col 
= (idx_in_warpgroup/64)*(D_V/2) + k*B_EPI + i*8;\n                *(uint128_t*)(&sO(smem_row, smem_col)) = *(uint128_t*)(o_bf16);\n            }\n\n            // Sync\n            fence_view_async_shared();\n            NamedBarrier::arrive_and_wait(128, 0);\n            \n            if (warp_idx == 0 && elect_one_sync()) {\n                cute::copy(\n                    tma_params.tma_O,\n                    thr_tma.partition_S(sO_divided(_, _, k)),\n                    thr_tma.partition_D(tma_gO(_, _, k))\n                );\n            }\n            if (warp_idx == 1 && elect_one_sync()) {\n                int k2 = k + (D_V/B_EPI/2);\n                cute::copy(\n                    tma_params.tma_O,\n                    thr_tma.partition_S(sO_divided(_, _, k2)),\n                    thr_tma.partition_D(tma_gO(_, _, k2))\n                );\n            }\n        }\n\n        if (warp_idx == 0) {\n            cute::TMEM::Allocator2Sm().free(0, 512);\n        }\n    } else if (warpgroup_idx == 1) {\n        // Producer warp for K\n        cutlass::arch::warpgroup_reg_dealloc<96>();\n        int warp_idx = cutlass::canonical_warp_idx_sync() - 4;\n        constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/2)/4/NUM_WARPS;\n        if (elect_one_sync()) {\n            bf16* sK_base = plan.u.s.k.data() + warp_idx*4*64;\n\n            CUTE_NO_UNROLL\n            for (int k = 0; k < num_k_blocks; ++k) {\n                int4 indices[NUM_LOCAL_ROWS_PER_WARP];\n                int max_indices = -1, min_indices = params.s_kv;\n                CUTE_UNROLL\n                for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {\n                    indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx);\n                    max_indices = max(max_indices, int4_max(indices[local_row]));\n                    min_indices = min(min_indices, int4_min(indices[local_row]));\n               
 }\n                bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1;\n                bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS;\n                    \n                auto load_part_ki = [&](transac_bar_t &bar, int local_col_start, int local_col_end) {\n                    CUTE_UNROLL\n                    for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {\n                        CUTE_UNROLL\n                        for (int local_col = local_col_start; local_col < local_col_end; ++local_col)\n                            ku::tma_gather4_cta_group_2<true>(\n                                &(tma_params.tensor_map_kv),\n                                bar,\n                                sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64),\n                                local_col*64,\n                                indices[local_row],\n                                (int64_t)TMA::CacheHintSm90::EVICT_LAST\n                            );\n                    }\n                };\n\n                int cur_buf = k%NUM_BUFS;\n                if (k > 0) {\n                    plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);\n                }\n                if (!should_skip_tma) {\n                    load_part_ki(plan.bar_k_part0_ready[cur_buf], 0, D_sQ/64);\n                } else {\n                    // NOTE: TMA has performance issues when all indices are the same (even if those indices are invalid), so we detect whether all indices in our block are invalid (by inspecting their MIN and MAX, for performance reasons), and skip the copy if all indices are invalid.\n                    // NOTE: We can also skip the initial zero-fill procedure (which prevents NaN from appearing in K/V buf if the first TMA copy is skipped) by disabling skipping on the first NUM_BUFS TMAs.\n                    // NOTE: We only do this for K to save some checking overhead, since after 
doing this for K, cases where topk indices are all invalid are faster than the other cases\n                    plan.bar_k_part0_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_sQ*sizeof(bf16), 1u);\n                }\n\n                if (k > 0) {\n                    plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);\n                }\n                if (!should_skip_tma) {\n                    load_part_ki(plan.bar_k_part1_ready[cur_buf], D_sQ/64, D_K/64);\n                } else {\n                    plan.bar_k_part1_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_tQ*sizeof(bf16), 1u);\n                }\n            }\n        }\n    } else if (warpgroup_idx == 2) {\n        // Producer warps for V\n        cutlass::arch::warpgroup_reg_dealloc<96>();\n        int warp_idx = cutlass::canonical_warp_idx_sync() - 8;\n        constexpr int NUM_WARPS = 4;\n\n        if (elect_one_sync()) {\n            // Wait for UTCCP\n            plan.bar_prologue_utccp.wait(0);\n\n            bf16* sV_base = plan.u.s.v.data() + warp_idx*4*64;\n\n            CUTE_NO_UNROLL\n            for (int k = 0; k < num_k_blocks; ++k) {\n                auto load_part_vi = [&](transac_bar_t &bar, int local_row_start, int local_row_end) {\n                    CUTE_UNROLL\n                    for (int local_row = local_row_start; local_row < local_row_end; ++local_row) {\n                        int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);\n                        CUTE_UNROLL\n                        for (int local_col = 0; local_col < (D_V/2)/64; ++local_col)\n                            ku::tma_gather4_cta_group_2<true>(\n                                &(tma_params.tensor_map_kv),\n                                bar,\n                                sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),\n                                local_col*64 + (cta_idx?256:0),\n                    
            token_idxs,\n                                (int64_t)TMA::CacheHintSm90::EVICT_LAST\n                            );\n                    }\n                };\n\n                int cur_buf = k%NUM_BUFS;\n                if (k > 0) {\n                    plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);\n                }\n                load_part_vi(plan.bar_v_part0_ready[cur_buf], 0, (B_TOPK/2)/4/NUM_WARPS);\n\n                if (k > 0) {\n                    plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);\n                }\n                load_part_vi(plan.bar_v_part1_ready[cur_buf], (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS);\n            }\n        }\n    } else {\n        cutlass::arch::warpgroup_reg_alloc<168>();\n        \n        // MMA warp\n        if (cta_idx == 0 && warp_idx == 12 && elect_one_sync()) {\n            // S -> T copy for Q\n            UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(\n                make_tensor(\n                    make_smem_ptr(plan.u.q_full.data() + (B_H/2)*D_sQ),\n                    tile_to_shape(\n                        UMMA::Layout_K_SW128_Atom<bf16>{},\n                        Shape<Int<B_H/2>, Int<64>>{}\n                    )\n                )\n            );\n            plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16));\n            plan.bar_prologue_q.wait(0);\n            ku::tcgen05_after_thread_sync();\n            CUTE_UNROLL\n            for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) {\n                // A tile is 64 rows * 64 cols (128B)\n                CUTE_UNROLL\n                for (int subtile_idx = 0; subtile_idx < 8; ++subtile_idx) {\n                    // A subtile is 64 rows * 8 cols (128b)\n                    SM100_UTCCP_2x64dp128bitlw0213_2cta::copy(\n                        sQ_desc + tile_idx*((B_H/2)*128/16) + subtile_idx*(16/16),   // Remember that 4 LSBs are not included\n                       
 tmem_cols::q + tile_idx*32 + subtile_idx*4\n                    );\n                }\n            }\n            ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2);\n\n            CUTE_NO_UNROLL\n            for (int k = 0; k < num_k_blocks+1; ++k) {\n                if (k < num_k_blocks) {\n                    // Pi = QKi^T\n                    int cur_buf = k%NUM_BUFS;\n                    Tensor sQl = make_tensor(make_smem_ptr(plan.u.s.sq.data()), SmemLayoutQTiles<NUM_sQ_TILES>{});\n                    Tensor sKl = make_tensor(make_smem_ptr(plan.u.s.k.data()), SmemLayoutKTiles<NUM_sQ_TILES>{});\n                    Tensor sKr = make_tensor(make_smem_ptr(plan.u.s.k.data()+64*D_sQ), SmemLayoutKTiles<NUM_tQ_TILES>{});\n\n                    // Wait for K (part0)\n                    plan.bar_k_part0_ready[cur_buf].arrive_and_expect_tx(B_TOPK*D_sQ*sizeof(bf16));\n                    plan.bar_k_part0_ready[cur_buf].wait((k/NUM_BUFS)&1);\n                    if (k > 0) {\n                        plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);\n                    }\n                    ku::tcgen05_after_thread_sync();\n\n                    ku::utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true);\n                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2);\n\n                    // Wait for K (part1)\n                    plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16));\n                    plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1);\n                    ku::tcgen05_after_thread_sync();\n\n                    ku::utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false);\n                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2);\n                }\n                if (k > 0) {\n                    // O += S(i-1)V(i-1)\n                    int cur_buf = (k-1)%NUM_BUFS;\n\n                    Tensor sS = 
make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutSTiles<2>{});\n                    Tensor sV = make_tensor(make_smem_ptr(plan.u.s.v.data()), SmemLayoutV{});\n                    Tensor sS_divided = flat_divide(sS, Tile<Int<B_H/2>, _64>{})(_, _, _0{}, _);    // (B_H/2, 64, 2)\n                    Tensor sV_divided = flat_divide(sV, Tile<Int<D_V/2>, _64>{})(_, _, _0{}, _);  // (D_V/2, 64, 2)\n\n                    // Wait for S(i-1) and O to be scaled\n                    plan.bar_so_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);\n\n                    // Wait for V (part0), and issue O += sS @ sV\n                    plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));\n                    plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);\n                    ku::tcgen05_after_thread_sync();\n\n                    ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1);\n                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2);\n\n                    // Wait for V (part1), and issue O += sS @ sV\n                    plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));\n                    plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);\n                    ku::tcgen05_after_thread_sync();\n                    ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false);\n                    ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2);\n                }\n            }\n        } else if (warp_idx == 13) {\n            // KV valid loading warp\n            static_assert(B_TOPK == 128);\n            if (lane_idx < 16) {\n                CUTE_NO_UNROLL\n                for (int k = 0; k < num_k_blocks; ++k) {\n                    int cur_buf = k%NUM_BUFS;\n                    int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8);\n                    auto 
is_valid = [&](int rel_pos_in_lane, int index) -> char {\n                        int abs_pos = k*B_TOPK + lane_idx*8 + rel_pos_in_lane;\n                        return index >= 0 && index < params.s_kv && abs_pos < topk_length;\n                    };\n                    char is_ks_valid_mask = \\\n                        is_valid(7, indices.a7) << 7 | \n                        is_valid(6, indices.a6) << 6 | \n                        is_valid(5, indices.a5) << 5 |\n                        is_valid(4, indices.a4) << 4 |\n                        is_valid(3, indices.a3) << 3 |\n                        is_valid(2, indices.a2) << 2 |\n                        is_valid(1, indices.a1) << 1 |\n                        is_valid(0, indices.a0) << 0;\n\n                    plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);\n                    plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask;\n                    plan.bar_k_valid_ready[cur_buf].arrive();\n                }\n            }\n        }\n    }\n\n\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\");\n    }\n#endif\n}\n\ntemplate<typename Kernel, typename TmaParams>\n__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)\nsparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) {\n    Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);\n}\n\ntemplate<int D_QK>\nvoid run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {\n    static_assert(D_QK == 576 || D_QK == 512);\n    using Kernel = KernelTemplate<D_QK>;\n\n    KU_ASSERT(params.h_kv == 1);\n    KU_ASSERT(params.topk % Kernel::B_TOPK == 0);   // To save some boundary checks\n    KU_ASSERT(params.h_q == Kernel::B_H);  // To save some calculation\n    KU_ASSERT(params.d_qk == D_QK);\n\n    auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);\n    auto tma_Q = cute::make_tma_copy(\n        
SM100_TMA_2SM_LOAD_NOSPLIT{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.q),\n            make_layout(\n                shape_Q,\n                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)\n            )\n        ),\n        (typename Kernel::template SmemLayoutQTiles<D_QK/64>){}\n    );\n\n    auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);\n    auto tma_O = cute::make_tma_copy(\n        SM90_TMA_STORE{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.out),\n            make_layout(\n                shape_O,\n                make_stride(params.d_v, _1{}, params.h_q*params.d_v)\n            )\n        ),\n        (typename Kernel::template SmemLayoutOTiles<1>){}\n    );\n\n    CUtensorMap tensor_map_kv;\n    {\n        uint64_t size[2] = {D_QK, (unsigned long)params.s_kv};\n        uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};\n        uint32_t box_size[2] = {64, 1};\n        uint32_t elem_stride[2] = {1, 1};\n        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(\n            &tensor_map_kv,\n            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            2,\n            params.kv,\n            size,\n            stride,\n            box_size,\n            elem_stride,\n            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,\n            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,\n            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,\n            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE\n        );\n        KU_ASSERT(res == CUresult::CUDA_SUCCESS);\n    }\n\n    TmaParams<\n        decltype(shape_Q), decltype(tma_Q),\n        decltype(shape_O), decltype(tma_O)\n    > tma_params = {\n        shape_Q, tma_Q,\n        shape_O, tma_O,\n        tensor_map_kv\n    };\n    auto kernel = &sparse_attn_fwd_kernel<Kernel, decltype(tma_params)>;\n\n    constexpr size_t smem_size = sizeof(typename 
Kernel::SharedMemoryPlan);\n    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n\n    cutlass::ClusterLaunchParams launch_params = {\n        dim3(2*params.s_q, 1, 1),\n        dim3(Kernel::NUM_THREADS, 1, 1),\n        dim3(2, 1, 1),\n        smem_size,\n        params.stream\n    };\n    KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster(\n        launch_params, (void*)kernel, params, tma_params\n    ));\n}\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head128/phase1.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::fwd::head128 {\n\ntemplate<int D_QK>\nvoid run_fwd_phase1_kernel(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head64/config.h",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\n#include \"defines.h\"\n\nnamespace sm100::fwd::head64 {\n\nusing namespace cute;\n\ntemplate<\n    typename Shape_Q_NoPE, typename TMA_Q_NoPE,\n    typename Shape_Q_RoPE, typename TMA_Q_RoPE,\n    typename Shape_O, typename TMA_O\n>\nstruct TmaParams {\n    Shape_Q_NoPE shape_Q_nope; TMA_Q_NoPE tma_Q_nope;\n    Shape_Q_RoPE shape_Q_rope; TMA_Q_RoPE tma_Q_rope;\n    Shape_O shape_O; TMA_O tma_O;\n    CUtensorMap tensor_map_kv_nope;\n};\n\nstruct float2x2 {\n    float2 lo, hi;\n};\n\nconstexpr int D_Q = 576;\nconstexpr int D_K = 576;\nconstexpr int D_V = 512;\nconstexpr float MAX_INIT_VAL = -1e30;    // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan\n\nconstexpr int B_H = 64;\nconstexpr int B_TOPK = 64;\nconstexpr int NUM_BUFS = 3;\nconstexpr int NUM_THREADS = 128 + 128 + 128; // 128 scale & exp threads, 128 TMA threads, 32 UTCMMA threads\n\n\n// Tensor memory columns\nnamespace tmem_cols {\n    //   0 ~ 256: output\n    // 256 ~ 400: Q\n    // 400 ~ 464: P\n    constexpr int O = 0;\n    constexpr int Q = 256;\n    constexpr int Q_RoPE = 256 + 128;\n    constexpr int P = 400;\n}\n\nusing SmemLayoutQNoPE = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<D_V>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutQRoPE = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW64_Atom<bf16>{},\n    Shape<Int<B_H>, Int<D_Q-D_V>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutOTiles = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutO = SmemLayoutOTiles<8>;\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_TOPK>, 
Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutKNoPE = SmemLayoutKTiles<8>;\nusing SmemLayoutV = decltype(coalesce(\n    composition(\n        SmemLayoutKNoPE{},\n        Layout<Shape<Int<D_V>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{}\n    )\n, Shape<_1, _1>{}));\n\nusing SmemLayoutKRoPE = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW64_Atom<bf16>{},\n    Shape<Int<B_TOPK>, Int<64>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutKNoPE_TiledMMA = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_TOPK*2>, Int<D_V/2>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));   // Re-view K-NoPE as B_TOPK*2 x D_V/2 for dual gemm\n\nusing SmemLayoutKRoPE_TiledMMA = decltype(coalesce(tile_to_shape(\n    UMMA::Layout_K_SW64_Atom<bf16>{},\n    Shape<Int<B_TOPK*2>, Int<64/2>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\nusing SmemLayoutS = decltype(coalesce(tile_to_shape(\n\tUMMA::Layout_K_INTER_Atom<bf16>{},\n\tShape<Int<B_H>, Int<B_TOPK>>{},\n\tStep<_1, _2>{}\n), Shape<_1, _1>{}));\n\n\nstruct SharedMemoryPlan {\n    union {\n        struct {\n            array_aligned<bf16, cosize_v<SmemLayoutKRoPE>> _k_rope_pad;\n            array_aligned<bf16, cosize_v<SmemLayoutKNoPE>> _k_pad[2];   // So that q_nope covers k[2]\n            array_aligned<bf16, cosize_v<SmemLayoutQNoPE>> q_nope;\n        } q_full;\n        struct {\n            array_aligned<bf16, cosize_v<SmemLayoutKRoPE>> k_rope;\n            array_aligned<bf16, cosize_v<SmemLayoutKNoPE>> k_nope[NUM_BUFS];\n        } k;\n        array_aligned<bf16, cosize_v<SmemLayoutO>> o;\n    } u;\n    float p_exchange_buf[4][32 * (B_TOPK/2)];\n    union {\n        bf16 s[B_H*B_TOPK];\n        array_aligned<bf16, cosize_v<SmemLayoutQRoPE>> q_rope;\n    } s_q_rope;\n    char is_k_valid[NUM_BUFS][B_TOPK/8];\n    transac_bar_t bar_prologue_q_nope, bar_prologue_q_rope, bar_prologue_utccp_nope, bar_prologue_utccp_rope;\n    transac_bar_t 
bar_qk_nope_done[NUM_BUFS], bar_qk_rope_done;    // Pi = QKi^T (the nope part) done\n    transac_bar_t bar_sv_done[NUM_BUFS];    // O += SiVi done (i.e. O, Si and Vi are free)\n    transac_bar_t bar_kv_nope_ready[NUM_BUFS][2], bar_kv_rope_ready;\n    transac_bar_t bar_p_free;\n    transac_bar_t bar_so_ready;   // S and O are ready\n    transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];\n    array_aligned<uint32_t, 1> tmem_start_addr;\n    float rowwise_max_buf[128], rowwise_li_buf[128];\n};\n\nusing TiledMMA_P = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, 128, UMMA::Major::K, UMMA::Major::K>{}  // Here we use N = 128 = 2*B_TOPK since we're going to use implicit dual gemm: <TODO Fill link here>\n));\n\nusing TiledMMA_O = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}\n));\n\nenum NamedBarriers : int {\n    wg0_sync = 0,\n    wg0_warp02_sync = 1,\n    wg0_warp13_sync = 2,\n    pepi_sync = 3,\n};\n\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head64 {\n\ntemplate void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head64 {\n\ntemplate void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh",
    "content": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cutlass/arch/arch.h>\n#include <cutlass/cuda_host_adapter.hpp>\n\n#include <kerutils/kerutils.cuh>\n\n#include \"params.h\"\n#include \"utils.h\"\n#include \"sm100/helpers.h\"\n#include \"sm100/prefill/sparse/common_subroutine.h\"\n#include \"config.h\"\n\nnamespace sm100::fwd::head64 {\n\nusing namespace cute;\n\n/*\nPipeline Overview:\n\n| Copy |    MMA    |   Scale & Exp   |\n\nKV0\nKV1\nKV2\n        P0 = QK0^T\n                    S0 = exp(P0)\n                    scale(O) w.r.t P0\n        P1 = QK1^T\n                    S1 = exp(P1)\n        O += S0V0\nKV3                 scale(O) w.r.t P1\n        P2 = QK2^T\n                    S2 = exp(P2)\n        O += S1V1\nKV4                 scale(O) w.r.t P2\n        P3 = QK3^T\n                    S3 = exp(P3)\n        O += S2V2\nKV5                 scale(O) w.r.t P3\n\n...\n\n        O += S(n-3)V(n-3)\n                    scale(O) w.r.t P(n-2)\n        P(n-1) = QK(n-1)^T\n                   S(n-1) = exp(P(n-1))\n        O += S(n-2)V(n-2)\n                   scale(O) w.r.t P(n-1)\n        O += S(n-1)V(n-1)\n*/\n\nusing FwdMode = SparseAttnFwdMode;\n\ntemplate<bool HAVE_ROPE, typename TmaParams>\n__global__ void __launch_bounds__(NUM_THREADS, 1, 1)\nsparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) {\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))\n    // Grid shape: [s_q, 1, 1]\n\n    const int s_q_idx = blockIdx.x;\n    const int warp_idx = cutlass::canonical_warp_idx_sync();\n    const int lane_idx = threadIdx.x % 32;\n    const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);\n    const int idx_in_warpgroup = threadIdx.x % 128;\n    const int topk_length = params.topk_length != 
nullptr ? __ldg(params.topk_length + s_q_idx) : params.topk;\n    const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1);  // num_k_blocks always >= 1\n\n    // Define shared tensors\n    extern __shared__ char wksp_buf[];\n    SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);\n\n    int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]\n\n    // Allocate tmem tensors\n    TiledMMA tiled_mma_P = TiledMMA_P{};\n    TiledMMA tiled_mma_O = TiledMMA_O{};\n    // NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm)\n    Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<B_H>, _128>{});\n    Tensor tQ_nope_part0 = tiled_mma_P.get_slice(_0{}).make_fragment_A(\n        partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<(D_V/2)/2>>{})\n    );\n    Tensor tQ_nope_part1 = tiled_mma_P.get_slice(_0{}).make_fragment_A(\n        partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<(D_V/2)/2>>{})\n    );\n    Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A(\n        partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<64/2>>{})\n    );\n    Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H>, Int<D_V>>{});\n    tP.data().get() = tmem_cols::P;\n    tQ_nope_part0.data().get() = tmem_cols::Q;\n    tQ_nope_part1.data().get() = tmem_cols::Q + 64;\n    tQ_rope.data().get() = tmem_cols::Q_RoPE;\n    tO.data().get() = tmem_cols::O;\n\n    if (warp_idx == 0) {\n        if (elect_one_sync()) {\n            // Copy Q\n            if constexpr (HAVE_ROPE) {\n                cute::prefetch_tma_descriptor(tma_params.tma_Q_rope.get_tma_descriptor());\n            }\n            cute::prefetch_tma_descriptor(tma_params.tma_Q_nope.get_tma_descriptor());\n\n            plan.bar_prologue_q_nope.init(1);\n            plan.bar_prologue_q_rope.init(1);\n            fence_barrier_init();\n            \n            if constexpr (HAVE_ROPE) {\n           
     Tensor gQ_rope = tma_params.tma_Q_rope.get_tma_tensor(tma_params.shape_Q_rope)(_, _, s_q_idx);\n                Tensor sQ_rope = make_tensor(make_smem_ptr(plan.s_q_rope.q_rope.data()), SmemLayoutQRoPE{});\n                ku::launch_tma_copy(tma_params.tma_Q_rope, gQ_rope, sQ_rope, plan.bar_prologue_q_rope, TMA::CacheHintSm90::EVICT_FIRST);\n            }\n\n            Tensor gQ_nope = tma_params.tma_Q_nope.get_tma_tensor(tma_params.shape_Q_nope)(_, _, s_q_idx);\n            Tensor sQ_nope = make_tensor(make_smem_ptr(plan.u.q_full.q_nope.data()), SmemLayoutQNoPE{});\n            ku::launch_tma_copy(tma_params.tma_Q_nope, gQ_nope, sQ_nope, plan.bar_prologue_q_nope, TMA::CacheHintSm90::EVICT_FIRST);\n\n            cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());\n            cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv_nope));\n            \n            // Initialize other barriers\n            plan.bar_prologue_utccp_rope.init(1);\n            plan.bar_prologue_utccp_nope.init(1);\n            CUTE_UNROLL\n            for (int i = 0; i < NUM_BUFS; ++i) {\n                plan.bar_qk_nope_done[i].init(1);\n                plan.bar_sv_done[i].init(1);\n                plan.bar_kv_nope_ready[i][0].init(1);\n                plan.bar_kv_nope_ready[i][1].init(1);\n                plan.bar_k_valid_ready[i].init(B_TOPK/8);\n                plan.bar_k_valid_free[i].init(128);\n            }\n            plan.bar_p_free.init(128);\n            plan.bar_so_ready.init(128);\n            plan.bar_qk_rope_done.init(1);\n            plan.bar_kv_rope_ready.init(64);\n            fence_barrier_init();\n        }\n\n        // Initialize TMEM\n        cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());\n        TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);\n        cute::TMEM::Allocator1Sm().release_allocation_lock();\n    }\n\n    __syncthreads();\n\n    if (warpgroup_idx == 0) {\n        // Scale & Exp 
warps\n\n        // The following three numbers are \n        // - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V)\n        // - li: sumexp, i.e. li := sum(exp(Pi*scale - mi))\n        // - real_mi: real max logits, i.e. real_mi := max(Pi*scale)\n        // where Pi is the i-th row of P, P := QK^T\n        // mi and real_mi are always consistent within the two threads that\n        // control one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update\n        float mi = MAX_INIT_VAL;\n        float li = 0.0f;\n        float real_mi = -CUDART_INF_F;\n\n        bf16* sS_base = plan.s_q_rope.s + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2);\n        static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2;\n\n        CUTE_NO_UNROLL\n        for (int k = 0; k < num_k_blocks; ++k) {\n            // Wait for P\n            NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1));\n            plan.bar_qk_nope_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);\n            plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1);    // Put the barrier wait here for more code reordering space\n            ku::tcgen05_after_thread_sync();\n            \n            // Load P\n            float p[NUM_ELEMS_PER_THREAD];\n            retrieve_mask_and_reduce_p<\n                NUM_ELEMS_PER_THREAD,\n                tmem_cols::P,\n                NamedBarriers::wg0_warp02_sync,\n                NamedBarriers::wg0_warp13_sync,\n                false\n            >(\n                plan.is_k_valid[k%NUM_BUFS],\n                warp_idx, lane_idx, \n                [&]() {plan.bar_p_free.arrive();},\n                plan.p_exchange_buf,\n                p\n            );\n            plan.bar_k_valid_free[k%NUM_BUFS].arrive();\n            \n            // Get rowwise max of Pi\n            float cur_pi_max = get_max<NUM_ELEMS_PER_THREAD>(p);\n            cur_pi_max *= params.sm_scale_div_log2;\n\n            
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;\n            NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);\n            cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);\n            real_mi = max(real_mi, cur_pi_max);\n            bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);\n            // By this point:\n            // - cur_pi_max, real_mi, and mi are identical within each row (i.e. thread 0+64, 1+65, ...)\n            // - should_scale_o is identical among every warp, and is identical among threads that control the same row (i.e. among threads 0~31+64~95; and is identical among threads 32~63+96~127)\n\n\n            // Calc scale factor, and scale li\n            float new_max, scale_for_old;\n            if (!should_scale_o) {\n                // Don't scale O\n                scale_for_old = 1.0f;\n                new_max = mi;\n            } else {\n                new_max = max(cur_pi_max, mi);\n                scale_for_old = exp2f(mi - new_max);\n            }\n            mi = new_max;   // mi is still identical within each row\n\n            // Calculate S\n            nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2];\n            float cur_sum = get_s_from_p<NUM_ELEMS_PER_THREAD>(s, p, params.sm_scale_div_log2, new_max);\n            li = fma(li, scale_for_old, cur_sum);\n\n            // Wait for last SV gemm, write S\n            if (k > 0) {\n                plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);\n            }\n            CUTE_UNROLL\n            for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; i += 1) {\n                *(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4);\n            }\n\n            // Scale O\n            if (k > 0 && should_scale_o) {\n                // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);   // NOTE We have waited for last SV gemm before\n                ku::tcgen05_after_thread_sync();\n                rescale_O<D_V, 32, 
tmem_cols::O>(scale_for_old);\n                ku::tcgen05_before_thread_sync();\n            }\n            \n            fence_view_async_shared();\n            plan.bar_so_ready.arrive();\n        }\n\n        // Epilogue\n\n        if (real_mi == -CUDART_INF_F) {\n            // real_mi == -CUDART_INF_F <=> No valid TopK indices\n            // We set li to 0 to fit the definition that li := exp(x[i] - mi)\n            li = 0.0f;\n            mi = -CUDART_INF_F;\n        }\n        \n        // Exchange li\n        plan.rowwise_li_buf[idx_in_warpgroup] = li;\n        NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);\n        li += plan.rowwise_li_buf[idx_in_warpgroup^64];\n\n        // Store mi and li\n        if (idx_in_warpgroup < 64) {\n            int global_index = s_q_idx*params.h_q + idx_in_warpgroup;\n            float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));\n            cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;\n            params.max_logits[global_index] = real_mi*CUDART_LN2_F;\n            params.lse[global_index] = cur_lse;\n        }\n\n        // Wait for the last GEMM\n        plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);\n        ku::tcgen05_after_thread_sync();\n\n        // Fetch dO if necessary\n\n        // Store O\n        float attn_sink = params.attn_sink == nullptr ? 
-CUDART_INF_F : __ldg(params.attn_sink + (idx_in_warpgroup%64))*CUDART_L2E_F;\n        float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi));\n        Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});\n        constexpr int B_EPI = 64;\n        Tensor tma_gO = flat_divide(\n            tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx),\n            Shape<Int<B_H>, Int<B_EPI>>{}\n        )(_, _, _0{}, _);\n        Tensor sO_divided = flat_divide(\n            sO,\n            Shape<Int<B_H>, Int<B_EPI>>{}\n        )(_, _, _0{}, _);\n        auto thr_tma = tma_params.tma_O.get_slice(_0{});\n\n        float2 o[B_EPI/2];\n        bool have_valid_indices = __any_sync(0xffffffff, li != 0);  // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld\n        if (!have_valid_indices) {\n            // If there are no valid indices, we set o[i] to 0 and don't load from TMEM\n            CUTE_UNROLL\n            for (int i = 0; i < B_EPI/2; ++i)\n                o[i].x = o[i].y = 0.0f;\n            output_scale = 1.0f;\n        }\n\n        float2 output_scale_float2 = make_float2(output_scale, output_scale);\n\n        bf16* sO_addrs[8];\n        CUTE_UNROLL\n        for (int i = 0; i < B_EPI/8; ++i) {\n            sO_addrs[i] = &sO(idx_in_warpgroup%64, i*8);\n        }\n\n        CUTE_UNROLL\n        for (int c = 0; c < 2; ++c) {\n            // Each tile: 64 x 256\n            CUTE_UNROLL\n            for (int k = 0; k < (D_V/4)/B_EPI; ++k) {\n                // Load O from tO\n                if (have_valid_indices) {\n                    ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + c*128 + k*B_EPI, o);\n                    cutlass::arch::fence_view_async_tmem_load();\n                }\n\n                // Convert and store\n                CUTE_UNROLL\n                for (int i = 0; i < B_EPI/8; ++i) {\n                    nv_bfloat162 o_bf16[4];\n                    
CUTE_UNROLL\n                    for (int j = 0; j < 4; ++j) {\n                        o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2);\n                        o_bf16[j] = __float22bfloat162_rn(o[i*4+j]);\n                    }\n                    *(uint128_t*)(sO_addrs[i] + (c*(D_V/2) + (idx_in_warpgroup/64)*(D_V/4) + k*B_EPI)*64) = *(uint128_t*)(o_bf16);\n                }\n\n                // Sync\n                fence_view_async_shared();\n                NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);\n                \n                if (warp_idx == 0 && elect_one_sync()) {\n                    int epi_chunk_idx = c*(D_V/2/B_EPI) + k;\n                    cute::copy(\n                        tma_params.tma_O,\n                        thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)),\n                        thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx))\n                    );\n                }\n                if (warp_idx == 1 && elect_one_sync()) {\n                    int epi_chunk_idx = c*(D_V/2/B_EPI) + (D_V/B_EPI/4) + k;\n                    cute::copy(\n                        tma_params.tma_O,\n                        thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)),\n                        thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx))\n                    );\n                }\n            }\n        }\n\n\n        if (warp_idx == 0) {\n            cute::TMEM::Allocator1Sm().free(0, 512);\n        }\n    } else if (warpgroup_idx == 1) {\n        // Producer warp for KV\n        int warp_idx = cutlass::canonical_warp_idx_sync() - 4;\n        constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/4)/NUM_WARPS;\n        if (elect_one_sync()) {\n            CUTE_NO_UNROLL\n            for (int k = 0; k < num_k_blocks; ++k) {\n                int4 indices[NUM_LOCAL_ROWS_PER_WARP];\n                int max_indices = -1, min_indices = params.s_kv;\n                CUTE_UNROLL\n                for (int 
local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {\n                    indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);\n                    max_indices = max(max_indices, int4_max(indices[local_row]));\n                    min_indices = min(min_indices, int4_min(indices[local_row]));\n                }\n                bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1;\n                bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS;\n\n                if (k == 2) {\n                    plan.bar_prologue_utccp_nope.wait(0);   // Since q_nope coincides with k[2]\n                }\n\n                // Copy NoPE\n                int cur_buf = k%NUM_BUFS;\n                plan.bar_sv_done[cur_buf].wait((k/NUM_BUFS)&1^1);\n                bf16* sK_nope_base = plan.u.k.k_nope[cur_buf].data() + warp_idx*4*64;\n\n                auto load_kv_nope_part = [&](int part_idx) {\n                    CUTE_UNROLL\n                    for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {\n                        CUTE_UNROLL\n                        for (int local_col = part_idx*(D_V/2/64); local_col < (part_idx+1)*(D_V/2/64); ++local_col) {\n                            ku::tma_gather4(\n                                &(tma_params.tensor_map_kv_nope),\n                                plan.bar_kv_nope_ready[cur_buf][part_idx],\n                                sK_nope_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),\n                                local_col*64,\n                                indices[local_row],\n                                (int64_t)TMA::CacheHintSm90::EVICT_LAST\n                            );\n                        }\n                    }\n                };\n\n                if (!should_skip_tma) {\n                    load_kv_nope_part(0);\n                    load_kv_nope_part(1);\n                } else {\n              
      // NOTE See head128/phase1.cuh for this TMA skipping technique\n                    CUTE_UNROLL\n                    for (int part_idx = 0; part_idx < 2; ++part_idx)\n                        plan.bar_kv_nope_ready[cur_buf][part_idx].complete_transaction(NUM_LOCAL_ROWS_PER_WARP*4*D_V/2*sizeof(bf16));\n                }\n            }\n        }\n    } else {\n        // MMA warp\n        if (warp_idx == 8 && elect_one_sync()) {\n            // S -> T copy for Q\n            UMMA::SmemDescriptor sQ_nope_desc = UMMA::make_umma_desc<UMMA::Major::K>(\n                make_tensor(\n                    make_smem_ptr(plan.u.q_full.q_nope.data()),\n                    tile_to_shape(\n                        UMMA::Layout_K_SW128_Atom<bf16>{},\n                        Shape<Int<B_H*2>, Int<64>>{}    // We use this shape for dual gemm (TODO Link)\n                    )\n                )\n            );\n            UMMA::SmemDescriptor sQ_rope_desc = UMMA::make_umma_desc<UMMA::Major::K>(\n                make_tensor(\n                    make_smem_ptr(plan.s_q_rope.q_rope.data()),\n                    tile_to_shape(\n                        UMMA::Layout_K_SW64_Atom<bf16>{},\n                        Shape<Int<B_H*2>, Int<32>>{}\n                    )\n                )\n            );\n            \n            if constexpr (HAVE_ROPE) {\n                // Copy the RoPE tile: 128 rows * 32 cols (64B) (in UTCCP's view), or 64 rows * 64 cols (in our view)\n                plan.bar_prologue_q_rope.arrive_and_expect_tx(B_H*(D_Q-D_V)*sizeof(bf16));\n                plan.bar_prologue_q_rope.wait(0);\n                ku::tcgen05_after_thread_sync();\n                CUTE_UNROLL\n                for (int subtile_idx = 0; subtile_idx < 2; ++subtile_idx) {\n                    // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)\n                    SM100_UTCCP_128dp256bit_1cta::copy(\n                        sQ_rope_desc + 
(subtile_idx*32) / 16,\n                        tmem_cols::Q_RoPE + subtile_idx*8\n                    );\n                }\n                ku::umma_arrive_noelect(plan.bar_prologue_utccp_rope);\n            }\n\n            plan.bar_prologue_q_nope.arrive_and_expect_tx(B_H*D_V*sizeof(bf16));\n            plan.bar_prologue_q_nope.wait(0);\n            ku::tcgen05_after_thread_sync();\n            CUTE_UNROLL\n            for (int tile_idx = 0; tile_idx < D_V/64/2; ++tile_idx) {\n                // A tile is 128 rows * 64 cols (128B) (in UTCCP's view), or 64 rows * 128 cols (in our view)\n                CUTE_UNROLL\n                for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) {\n                    // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)\n                    SM100_UTCCP_128dp256bit_1cta::copy(\n                        sQ_nope_desc + (tile_idx*(B_H*128*2) + subtile_idx*32) / 16,   // Remember that 4 LSBs are not included\n                        tmem_cols::Q + tile_idx*32 + subtile_idx*8\n                    );\n                }\n            }\n            ku::umma_arrive_noelect(plan.bar_prologue_utccp_nope);\n\n            if constexpr (HAVE_ROPE) {\n                plan.bar_prologue_utccp_rope.wait(0);\n            }\n\n            CUTE_NO_UNROLL\n            for (int k = 0; k < num_k_blocks+1; ++k) {\n                if (k < num_k_blocks) {\n                    // Pi = QKi^T\n                    int cur_buf = k%NUM_BUFS;\n                    Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutKNoPE_TiledMMA{});\n                    Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE_TiledMMA{});\n\n                    plan.bar_p_free.wait(k&1^1);\n                    ku::tcgen05_after_thread_sync();\n                    \n                    // Wait for K (RoPE)\n                    // P = Q(rope) @ K(rope)^T\n       
             if constexpr (HAVE_ROPE) {\n                        plan.bar_kv_rope_ready.wait(k&1);\n                        ku::tcgen05_after_thread_sync();\n                        ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true);\n                        ku::umma_arrive_noelect(plan.bar_qk_rope_done);\n                    }\n\n                    // Wait for K (NoPE)\n                    if (k == 0) {\n                        plan.bar_prologue_utccp_nope.wait(0);\n                    }\n                    Tensor sK_nope_divided = flat_divide(sK_nope, Tile<Int<B_TOPK*2>, Int<D_V/4>>{})(_, _, _0{}, _);\n                    CUTE_UNROLL\n                    for (int kv_nope_part_idx = 0; kv_nope_part_idx < 2; ++kv_nope_part_idx) {\n                        plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].arrive_and_expect_tx(B_TOPK*D_V/2*sizeof(bf16));\n                        plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].wait((k/NUM_BUFS)&1);\n                        ku::tcgen05_after_thread_sync();\n\n                        // P += Q(nope) @ K(nope)^T\n                        bool clear_accum = (!HAVE_ROPE) && kv_nope_part_idx == 0;\n                        ku::utcmma_ts(tiled_mma_P, kv_nope_part_idx ? 
tQ_nope_part1 : tQ_nope_part0, sK_nope_divided(_, _, kv_nope_part_idx), tP, clear_accum);\n                    }\n                    ku::umma_arrive_noelect(plan.bar_qk_nope_done[cur_buf]);\n                }\n                if (k > 0) {\n                    // O += S(i-1)V(i-1)\n                    int cur_buf = (k-1)%NUM_BUFS;\n\n                    Tensor sS = make_tensor(make_smem_ptr(plan.s_q_rope.s), SmemLayoutS{});\n                    Tensor sV = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutV{});\n\n                    // Wait for S(i-1) and O to be scaled\n                    plan.bar_so_ready.wait((k-1)&1);\n                    ku::tcgen05_after_thread_sync();\n\n                    // O += sS @ sV\n                    ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == 1);\n                    ku::umma_arrive_noelect(plan.bar_sv_done[cur_buf]);\n                }\n            }\n        } else if (warp_idx == 9) {\n            // KV valid loading warp\n            if (lane_idx < B_TOPK/8) {\n                CUTE_NO_UNROLL\n                for (int k = 0; k < num_k_blocks; ++k) {\n                    char k_validness_mask = load_indices_and_generate_mask(\n                        lane_idx,\n                        gIndices + k*B_TOPK,\n                        params.s_kv,\n                        k*B_TOPK,\n                        topk_length\n                    );\n\n                    int cur_buf = k%NUM_BUFS;\n                    plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);\n                    plan.is_k_valid[cur_buf][lane_idx] = k_validness_mask;\n                    plan.bar_k_valid_ready[cur_buf].arrive();\n                }\n            }\n        } else if (warp_idx == 10 || warp_idx == 11) {\n            if constexpr (HAVE_ROPE) {\n                int thread_idx = threadIdx.x - 10*32;\n                constexpr int GROUP_SIZE = 8, NUM_GROUPS = 64/GROUP_SIZE, ROWS_PER_THREAD = B_TOPK/NUM_GROUPS;\n                int 
group_idx = thread_idx / GROUP_SIZE, idx_in_group = thread_idx % GROUP_SIZE;\n                Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE{});\n                bf16* sK_rope_base = &sK_rope(group_idx, idx_in_group*8);\n                CUTE_NO_UNROLL\n                for (int k = 0; k < num_k_blocks; ++k) {\n                    int indices[ROWS_PER_THREAD];\n                    CUTE_UNROLL\n                    for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row)\n                        indices[local_row] = __ldg(gIndices + k*B_TOPK + group_idx + local_row*NUM_GROUPS);\n                    plan.bar_qk_rope_done.wait(k&1^1);\n                    CUTE_UNROLL\n                    for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row) {\n                        int index = indices[local_row];\n                        ku::cp_async_cacheglobal<ku::PrefetchSize::B128>(\n                            params.kv + (int64_t)index*params.stride_kv_s_kv + 512 + idx_in_group*8,\n                            sK_rope_base + local_row*NUM_GROUPS*32,\n                            index >= 0 && index < params.s_kv\n                        );  // NOTE Using cp.async instead of TMA is faster here\n                        // NOTE Here we only consider the range of `index` instead of also checking against topk_length, as it's noted that under this scenario (i.e. 
there exists a valid index among indices[topk_length: ] that points to a token who has NaN inside)\n                    }\n                    cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)&(plan.bar_kv_rope_ready));\n                }\n            }\n        }\n    }\n\n\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\");\n    }\n#endif\n}\n\ntemplate<int D_QK>\nvoid run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {\n    KU_ASSERT(params.h_kv == 1);\n    KU_ASSERT(params.topk % B_TOPK == 0);   // To save some boundary checks\n    KU_ASSERT(params.h_q == B_H);  // To save some calculation\n    KU_ASSERT(params.d_qk == D_QK);\n    static_assert(D_QK == 576 || D_QK == 512);\n\n    auto shape_Q_nope = make_shape(params.h_q, D_V, params.s_q);\n    auto tma_Q_nope = cute::make_tma_copy(\n        SM90_TMA_LOAD{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.q),\n            make_layout(\n                shape_Q_nope,\n                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)\n            )\n        ),\n        SmemLayoutQNoPE{}\n    );\n\n    auto shape_Q_rope = make_shape(params.h_q, D_Q-D_V, params.s_q);\n    auto tma_Q_rope = cute::make_tma_copy(\n        SM90_TMA_LOAD{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.q + D_V),\n            make_layout(\n                shape_Q_rope,\n                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)\n            )\n        ),\n        SmemLayoutQRoPE{}\n    );\n\n    auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);\n    auto tma_O = cute::make_tma_copy(\n        SM90_TMA_STORE{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.out),\n            make_layout(\n                shape_O,\n                make_stride(params.d_v, _1{}, params.h_q*params.d_v)\n            )\n        ),\n        SmemLayoutOTiles<1>{}\n    );\n\n\n    CUtensorMap 
tensor_map_kv_nope;\n    {\n        uint64_t size[2] = {D_V, (unsigned long)params.s_kv};\n        uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};\n        uint32_t box_size[2] = {64, 1};\n        uint32_t elem_stride[2] = {1, 1};\n        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(\n            &tensor_map_kv_nope,\n            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            2,\n            params.kv,\n            size,\n            stride,\n            box_size,\n            elem_stride,\n            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,\n            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,\n            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,\n            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE\n        );\n        KU_ASSERT(res == CUresult::CUDA_SUCCESS);\n    }\n\n    TmaParams<\n        decltype(shape_Q_nope), decltype(tma_Q_nope),\n        decltype(shape_Q_rope), decltype(tma_Q_rope),\n        decltype(shape_O), decltype(tma_O)\n    > tma_params = {\n        shape_Q_nope, tma_Q_nope,\n        shape_Q_rope, tma_Q_rope,\n        shape_O, tma_O,\n        tensor_map_kv_nope\n    };\n    auto kernel = &sparse_attn_fwd_kernel<D_QK == 576, decltype(tma_params)>;\n\n    constexpr size_t smem_size = sizeof(SharedMemoryPlan);\n    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n\n    kernel<<<params.s_q, NUM_THREADS, smem_size, params.stream>>>(params, tma_params);\n    KU_CHECK_KERNEL_LAUNCH();\n}\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd/head64/phase1.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::fwd::head64 {\n\ntemplate<int D_QK>\nvoid run_fwd_phase1_kernel(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h",
    "content": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cutlass/float8.h>\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\n#include \"defines.h\"\n#include \"params.h\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\nusing namespace cute;\n\ntemplate<SparseAttnFwdMode FWD_MODE, int D_QK>\nstruct KernelTemplate {\n\nusing ArgT = SparseFwdArgT<FWD_MODE>;\nstatic constexpr bool IS_DECODE = is_decode_v<FWD_MODE>;\nstatic constexpr bool IS_PREFILL = !IS_DECODE;\nusing fp8_e4m3 = cutlass::float_e4m3_t;\nusing fp8_e8m0 = __nv_fp8_e8m0;\n\nstruct TmaParamsForPrefill {\n    CUtensorMap tensor_map_q;\n    CUtensorMap tensor_map_kv;\n    CUtensorMap tensor_map_o;\n};\n\nstruct TmaParamsForDecode {\n    CUtensorMap tensor_map_q;\n    CUtensorMap tensor_map_o;\n    CUtensorMap tensor_map_o_accum;\n    CUtensorMap tensor_map_kv_nope;\n    CUtensorMap tensor_map_kv_rope;\n    CUtensorMap tensor_map_extra_kv_nope;   // Only available if extra_kv is enabled\n    CUtensorMap tensor_map_extra_kv_rope;\n};\n\nusing TmaParams = std::conditional_t<\n    IS_DECODE,\n    TmaParamsForDecode,\n    TmaParamsForPrefill\n>;\n\nstatic_assert(D_QK == 512);\n\nstatic constexpr int D_Q = D_QK;\nstatic constexpr int D_K = D_QK;\nstatic constexpr int D_V = 512;\nstatic constexpr float MAX_INIT_VAL = -1e30;    // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan\n\nstatic constexpr int H_Q = 128;    // For 2 CTAs\nstatic constexpr int B_TOPK = 64; // For 2 CTAs\nstatic constexpr int NUM_THREADS = 128*4;\nstatic constexpr int NUM_WORKER_THREADS = IS_PREFILL ? (128 + 4 + (B_TOPK/8) + 1 + 128)*2 + 1 : (128 + 128 + 1 + 32 + 2 + 128)*2;\n\n// For non-decode mode, we have 4 (half-)KV buffers\n// For decode mode, we have 3 (half-)KV buffers with two raw KV buffers\nstatic constexpr int NUM_K_BUFS = IS_DECODE ? 3 : 4;\nstatic constexpr int NUM_RAW_K_BUFS = IS_DECODE ? 
2 : 0;\nstatic constexpr int NUM_INDEX_BUFS = IS_DECODE ? 4 : 4;\n\nstatic constexpr int D_NOPE = 448;\nstatic constexpr int D_ROPE = 64;\nstatic constexpr int TMA_K_STRIDE_FOR_DECODING = D_NOPE + 2*D_ROPE;\nstatic constexpr int NUM_SCALES_EACH_TOKEN = 8; // 7 scales + 1 padding\n\nstatic constexpr int B_EPI = 64;                // Epilogue block size for normal case (i.e. prefill or non-splitkv decoding)\nstatic constexpr int B_EPI_SPLITKV = 32;        // Epilogue block size for splitkv decoding\nstatic constexpr int NUM_EPI_SPLITKV_BUFS = 4;  // The number of epilogue buffers for splitkv decoding\nstatic_assert((H_Q/2)*D_Q*sizeof(bf16) >= NUM_EPI_SPLITKV_BUFS*(H_Q/2)*(B_EPI_SPLITKV*2)*sizeof(float));\n\n// Tensor memory columns\nstruct tmem_cols {\n    //   0 ~ 256: Output accumulator\n    // 256 ~ 384: Q\n    // 384 ~ 448: P\n    static constexpr int O = 0;\n    static constexpr int Q = 256;\n    static constexpr int P = 384;\n};\n\nstruct SharedMemoryPlan {\n    array_aligned<bf16, (H_Q/2)*D_Q> Q; // Will be output for epilogue\n    array_aligned<bf16, B_TOPK*(D_K/2)> K[NUM_K_BUFS];\n    array_aligned<fp8_e4m3, B_TOPK*(D_K/2)> K_raw[NUM_RAW_K_BUFS];\n    array_aligned<bf16, (H_Q/2)*B_TOPK> S;\n    float P_exchange[4][(H_Q/2/2)*(B_TOPK/2)];\n    float rowwise_max_buf[128], rowwise_li_buf[128];\n\n    CUTE_ALIGNAS(16) char is_k_valid[NUM_INDEX_BUFS][B_TOPK/8];\n    CUTE_ALIGNAS(16) int tma_coord[NUM_INDEX_BUFS][B_TOPK];\n    CUTE_ALIGNAS(16) fp8_e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN/2];\n    \n    transac_bar_t bar_sQ_full, bar_tQ_empty, bar_tQ_full;\n    transac_bar_t bar_tOut_full, bar_tOut_empty;\n    transac_bar_t bar_KV_full[NUM_K_BUFS], bar_KV_empty[NUM_K_BUFS];\n    transac_bar_t bar_P_empty;\n    transac_bar_t bar_QK_done, bar_SV_done;\n    transac_bar_t bar_S_O_full;\n    transac_bar_t bar_li_full, bar_li_empty;\n\n    // The following barriers are prefill-only\n    transac_bar_t bar_clc_full, bar_clc_empty;\n\n    // The following 
barriers are decode-only\n    transac_bar_t bar_raw_KV_full[NUM_RAW_K_BUFS], bar_raw_KV_empty[NUM_RAW_K_BUFS];\n    transac_bar_t bar_valid_coord_scales_full[NUM_INDEX_BUFS], bar_valid_coord_scales_empty[NUM_INDEX_BUFS];\n\n    ku::CLCResponseObj clc_response_obj;\n    array_aligned<uint32_t, 1> tmem_start_addr;\n};\n\nusing TiledMMA_P = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, H_Q, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}\n)); // *2 for dual gemm\n\nusing TiledMMA_O = decltype(make_tiled_mma(\n    SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, H_Q, 256, UMMA::Major::K, UMMA::Major::MN>{},\n    Layout<Shape<_1, _1, _1>>{},\n    Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{}  // We use this permutation layout so that CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]\n));\n\nstruct barrier_ids {\n    static constexpr int WG0_SYNC = 0;\n    static constexpr int WG2_SYNC = 1;\n    static constexpr int WG2_WARP02_SYNC = 2;\n    static constexpr int WG2_WARP13_SYNC = 3;\n};\n\nstatic __device__ void\nsparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params);\n\nstatic void run(const ArgT& params);\n\n};\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\ntemplate void run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::DecodeWithSplitKV, 512>(const SparseAttnDecodeParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\ntemplate void run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh",
    "content": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cluster_launch.hpp>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cutlass/arch/arch.h>\n\n#include \"params.h\"\n#include \"utils.h\"\n#include \"sm100/prefill/sparse/common_subroutine.h\"\n#include \"sm100/helpers.h\"\n\n#include \"config.h\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\nusing namespace cute;\nusing FwdMode = SparseAttnFwdMode;\n\ntemplate<FwdMode FWD_MODE, int D_QK>\n__device__ void\nKernelTemplate<FWD_MODE, D_QK>::sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params) {\n#ifdef KERUTILS_ENABLE_SM100A\n    // Grid shape: [2*s_q, 1, 1] for prefilling, [2*s_q, num_sm_parts, 1] for decoding\n    // Cluster shape: [2, 1, 1]\n    const int warp_idx = cutlass::canonical_warp_idx_sync();\n    const int lane_idx = threadIdx.x % 32;\n    const int warpgroup_idx = cutlass::canonical_warp_group_idx();\n    const int idx_in_warpgroup = threadIdx.x % 128;\n    const int cta_idx = block_id_in_cluster().x;\n\n    extern __shared__ char wksp_buf[];\n    SharedMemoryPlan &smem = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);\n\n    if (warp_idx == 0 && elect_one_sync()) {\n        cute::prefetch_tma_descriptor(&tma_params.tensor_map_q);\n        cute::prefetch_tma_descriptor(&tma_params.tensor_map_o);\n        if constexpr (IS_DECODE) {\n            cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);\n            cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);\n        } else {\n            cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv);\n        }\n    } else if (warp_idx == 1 && elect_one_sync()) {\n        smem.bar_sQ_full.init(1);\n        smem.bar_tQ_empty.init(1);\n        smem.bar_tQ_full.init(1);\n        smem.bar_tOut_full.init(1);\n        smem.bar_tOut_empty.init(256);\n        smem.bar_P_empty.init(256);\n        smem.bar_QK_done.init(1);\n        
smem.bar_SV_done.init(1);\n        smem.bar_S_O_full.init(256);\n        smem.bar_li_full.init(H_Q/2);\n        smem.bar_li_empty.init(128);\n        if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) {\n            smem.bar_clc_full.init(1);\n            smem.bar_clc_empty.init(NUM_WORKER_THREADS);\n        }\n        fence_barrier_init();\n    } else if (warp_idx == 2) {\n        cute::TMEM::Allocator2Sm().allocate(512, smem.tmem_start_addr.data());\n        KU_TRAP_ONLY_DEVICE_ASSERT(smem.tmem_start_addr.data()[0] == 0);\n        cute::TMEM::Allocator2Sm().release_allocation_lock();\n    } else if (warp_idx == 3 && elect_one_sync()) {\n        CUTE_UNROLL\n        for (int i = 0; i < NUM_K_BUFS; ++i) {\n            smem.bar_KV_full[i].init(IS_PREFILL ? 1 : (128/32)*2+1);\n            smem.bar_KV_empty[i].init(1);\n        }\n        CUTE_UNROLL\n        for (int i = 0; i < NUM_INDEX_BUFS; ++i) {\n            smem.bar_valid_coord_scales_full[i].init(IS_PREFILL ? B_TOPK/8 : 32);\n            smem.bar_valid_coord_scales_empty[i].init(IS_PREFILL ? 
128 : (128 + (cta_idx==1) + 2 + 128));\n        }\n        if constexpr (IS_DECODE) {\n            CUTE_UNROLL\n            for (int i = 0; i < NUM_RAW_K_BUFS; ++i) {\n                smem.bar_raw_KV_full[i].init(1);\n                smem.bar_raw_KV_empty[i].init(128);\n            }\n        }\n        fence_barrier_init();\n    }\n\n    ku::barrier_cluster_arrive_relaxed();\n    ku::barrier_cluster_wait_acquire();\n\n    struct OuterloopArgs {\n        bool outer_loop_phase;\n        int batch_idx, s_q_idx;\n        int start_block_idx, end_block_idx;\n        int topk_length;\n\n        int extra_topk_length, num_orig_kv_blocks;  // extra-KV related\n        bool is_no_split; int n_split_idx;  // splitkv related\n    };\n\n    auto run_outer_loop = [&](auto loop_body) -> bool {\n        int outer_loop_phase = false;\n        if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {\n            int s_q_idx = blockIdx.x / 2;\n            DecodingSchedMeta sched_meta;\n            KU_LDG_256(\n                params.tile_scheduler_metadata_ptr + blockIdx.y,\n                &sched_meta,\n                \".nc\",\n                \"no_allocate\",\n                \"evict_normal\",\n                \"256B\"\n            );\n            if (sched_meta.begin_req_idx >= params.b) {\n                return 0;\n            }\n\n            #pragma unroll 1\n            for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {\n                int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;\n                int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);\n                int extra_topk_length = params.extra_topk_length ? 
__ldg(params.extra_topk_length + batch_idx) : params.extra_topk;\n                int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK);    // % B_TOPK == 0\n                int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;\n                int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK;\n                bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false);\n                int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx);\n\n                    // start_block_idx = 0;\n                    // end_block_idx = total_topk_padded / B_TOPK;\n                    // is_split = false;\n                    // n_split_idx = 0;\n\n                OuterloopArgs args = {\n                    (bool)outer_loop_phase,\n                    batch_idx, s_q_idx,\n                    start_block_idx, end_block_idx,\n                    topk_length,\n\n                    extra_topk_length, orig_topk_padded / B_TOPK,\n                    !is_split, n_split_idx\n                };\n\n                loop_body(args);\n                outer_loop_phase ^= 1;\n            }\n        } else {\n            // Prefill mode. Use CLC to allocate different s_q (for decoding, different batches + s_q) to different workers\n            ku::CLCResult next_job = {true, (int)blockIdx.x, IS_PREFILL ? 0 : (int)blockIdx.y, 0};\n            CUTE_NO_UNROLL\n            while (next_job.is_valid) {\n                int s_q_idx = next_job.x / 2;\n                int batch_idx = IS_PREFILL ? 0 : next_job.y;\n                int topk_length = params.topk_length != nullptr ? 
__ldg(params.topk_length + (IS_PREFILL?s_q_idx:batch_idx)) : params.topk;\n                \n                if constexpr (IS_PREFILL) {\n                    int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1);  // num_k_blocks always >= 1\n                    OuterloopArgs args = {\n                        (bool)outer_loop_phase,\n                        0, s_q_idx,\n                        0, num_k_blocks,\n                        topk_length\n                    };\n                    loop_body(args);\n                } else {\n                    int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);\n                    int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;\n                    int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK);    // % B_TOPK == 0\n\n                    OuterloopArgs args = {\n                        (bool)outer_loop_phase,\n                        batch_idx, s_q_idx,\n                        0, total_topk_padded / B_TOPK,\n                        topk_length,\n\n                        extra_topk_length, orig_topk_padded / B_TOPK,\n                        false, 0\n                    };\n                    loop_body(args);\n                }\n\n                smem.bar_clc_full.wait(outer_loop_phase);\n                next_job = ku::get_clc_query_response<true>(smem.clc_response_obj);\n                smem.bar_clc_empty.arrive(0u);\n\n                outer_loop_phase ^= 1;\n            }\n        }\n        return outer_loop_phase;\n    };\n\n    if (warpgroup_idx == 0) {\n        // Q fetching and O writing back warpgroup\n        cutlass::arch::warpgroup_reg_alloc<176>();\n\n        bf16* sO_addrs[B_EPI/8];\n        CUTE_UNROLL\n        for (int i = 0; i < B_EPI/8; ++i) {\n            Tensor sO = make_tensor(make_smem_ptr(smem.Q.data()), ku::make_umma_canonical_k_major_layout<H_Q/2, D_V, 
128>());\n            sO_addrs[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*(D_V/2) + i*8);\n        }\n\n        float* sO_accum_addrs[B_EPI_SPLITKV/4];\n        if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {\n            // If split-KV is enabled, we need to store back O in float32\n            // We view Q buffer (with shape 64 x 512, bf16) as 4 buffers with shape (H_Q/2) x (B_EPI_SPLITKV*2), float32\n            Tensor sO_accum = make_tensor(make_smem_ptr((float*)smem.Q.data()), ku::make_umma_canonical_k_major_layout<H_Q/2, D_V, 128, float>());\n            CUTE_UNROLL\n            for (int i = 0; i < B_EPI_SPLITKV/4; ++i) {\n                sO_accum_addrs[i] = &sO_accum(idx_in_warpgroup%64, i*4) + (idx_in_warpgroup >= 64 ? (H_Q/2)*B_EPI_SPLITKV : 0);\n            }\n        }\n\n        auto perform_o_copy_out = [&](const OuterloopArgs &args, bool is_last_o) {\n            // outer_loop_phase is the loop phase corresponding to s_q_idx\n\n            // Get li (output_scale actually)\n            smem.bar_li_full.wait(args.outer_loop_phase);\n            float output_scale = smem.rowwise_li_buf[idx_in_warpgroup%64];\n            float2 output_scale_float2 = float2 {output_scale, output_scale};\n            smem.bar_li_empty.arrive();\n\n            // Retrieve and store O, and calculate delta := sum(O*dO, dim=-1) if FWD_MODE is Recompute\n            smem.bar_tOut_full.wait(args.outer_loop_phase);\n            if (is_last_o && elect_one_sync()) {\n                cudaTriggerProgrammaticLaunchCompletion();\n            }\n\n            if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {\n                CUTE_UNROLL\n                for (int k = 0; k < (D_V/2)/B_EPI; ++k) {\n                    float2 o[B_EPI/2];\n                    ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + k*B_EPI, o);\n                    cutlass::arch::fence_view_async_tmem_load();\n                    if (k == (D_V/2)/B_EPI-1) {\n                        
smem.bar_tOut_empty.arrive(0u);\n                    }\n                    CUTE_UNROLL\n                    for (int i = 0; i < B_EPI/8; ++i) {\n                        nv_bfloat162 o_bf16[4];\n                        CUTE_UNROLL\n                        for (int j = 0; j < 4; ++j) {\n                            o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2);\n                            o_bf16[j] = __float22bfloat162_rn(o[i*4+j]);\n                        }\n                        bf16* o_do_addr = sO_addrs[i] + k*B_EPI*(H_Q/2);\n                            if (k == 0 && i == 0) {\n                                smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o);    // Wait for sQ's availability\n                            }\n                        ku::st_shared(o_do_addr, *(__int128_t*)o_bf16);\n                    }\n                }\n\n                fence_view_async_shared();\n                NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);\n                if (warp_idx == 0 && elect_one_sync()) {\n                    SM90_TMA_STORE_5D::copy(\n                        &tma_params.tensor_map_o, \n                        smem.Q.data(),\n                        0, cta_idx*(H_Q/2), 0, args.s_q_idx, IS_DECODE ? 
args.batch_idx : 0\n                    );\n                    cute::tma_store_arrive();\n                }\n            } else {\n                CUTE_UNROLL\n                for (int k = 0; k < (D_V/2)/B_EPI_SPLITKV; ++k) {\n                    int cur_buf_idx = k % NUM_EPI_SPLITKV_BUFS;\n                    if (k == 0) {\n                        cute::tma_store_wait<0>();\n                    } else {\n                        cute::tma_store_wait<NUM_EPI_SPLITKV_BUFS-1>();\n                    }\n                    NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);\n\n                    float o[B_EPI_SPLITKV];\n                    ku::tmem_ld_32dp32bNx<B_EPI_SPLITKV>(tmem_cols::O + k*B_EPI_SPLITKV, o);\n                    cutlass::arch::fence_view_async_tmem_load();\n                    if (k == (D_V/2)/B_EPI_SPLITKV-1) {\n                        smem.bar_tOut_empty.arrive(0u);\n                    }\n                    CUTE_UNROLL\n                    for (int i = 0; i < B_EPI_SPLITKV/4; ++i) {\n                        CUTE_UNROLL\n                        for (int j = 0; j < 4; j += 2) {\n                            *(float2*)(o + i*4 + j) = ku::float2_mul(float2 {o[i*4+j], o[i*4+j+1]}, output_scale_float2);\n                        }\n                        if (k == 0 && i == 0) {\n                            smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o);    // Wait for sQ's availability\n                        }\n                        ku::st_shared(\n                            sO_accum_addrs[i] + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2),\n                            *(__int128_t*)(o + i*4)\n                        );\n                    }\n\n                    fence_view_async_shared();\n                    NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);\n                    if constexpr (IS_DECODE) {  // Otherwise nvcc complains about `tma_params` doesn't have `tensor_map_o_accum`\n                        float* 
cur_buf_base = (float*)smem.Q.data() + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2);\n                        if (warp_idx == 0 && elect_one_sync()) {\n                            SM90_TMA_STORE_5D::copy(\n                                &tma_params.tensor_map_o_accum, \n                                cur_buf_base,\n                                0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32), args.s_q_idx, args.n_split_idx\n                            );\n                            cute::tma_store_arrive();\n                        } else if (warp_idx == 1 && elect_one_sync()) {\n                            SM90_TMA_STORE_5D::copy(\n                                &tma_params.tensor_map_o_accum, \n                                cur_buf_base + (H_Q/2)*B_EPI_SPLITKV,\n                                0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32) + (D_V/2)/32, args.s_q_idx, args.n_split_idx\n                            );\n                            cute::tma_store_arrive();\n                        }\n                    }\n                }\n            }\n        };\n\n        OuterloopArgs last_args;\n        last_args.batch_idx = -1;\n\n        bool final_outer_loop_phase = \\\n        run_outer_loop([&](const OuterloopArgs &args) {\n            // Copy Q for this round\n            if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {\n                cute::tma_store_wait<0>();\n                NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);  // Since we use two warps to issue TMA during FwdMode::DecodeWithSplitKV\n            }\n            if (warp_idx == 0 && elect_one_sync()) {\n                // Wait for sQ to become empty, and issue G -> S copy for Q\n                if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) {\n                    cute::tma_store_wait<0>();  // This thread must be the same one as o copy out thread (since `elect_one_sync()` always returns the same thread for the same `mask`, according to PTX document)\n                }\n       
         int stride_q_b_div_stride_q_s_q = 0;\n                if constexpr (IS_DECODE) {\n                    stride_q_b_div_stride_q_s_q = params.stride_q_b / params.stride_q_s_q;\n                }\n                SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(\n                    &tma_params.tensor_map_q,\n                    (uint64_t*)&smem.bar_sQ_full,\n                    (uint64_t)TMA::CacheHintSm90::EVICT_FIRST,\n                    smem.Q.data(),\n                    0, cta_idx*(H_Q/2), 0, 0, (IS_DECODE ? args.batch_idx*stride_q_b_div_stride_q_s_q : 0) + args.s_q_idx\n                );\n\n                // Wait for sQ to be ready, and issue S -> T copy for Q\n                if (cta_idx == 0) {\n                    smem.bar_sQ_full.arrive_and_expect_tx(H_Q*D_Q*sizeof(bf16));\n                    smem.bar_sQ_full.wait(args.outer_loop_phase);\n\n                    smem.bar_tQ_empty.wait(args.outer_loop_phase^1);\n                    ku::tcgen05_after_thread_sync();\n                    UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(\n                        make_tensor(\n                            make_smem_ptr(smem.Q.data()),\n                            ku::make_umma_canonical_k_major_layout<(H_Q/2)*2, 64, 128>()\n                        )\n                    );\n                    CUTE_UNROLL\n                    for (int tile_idx = 0; tile_idx < D_Q/64/2; ++tile_idx) {\n                        // A tile is 128 rows * 64 cols in UTCCP's view, or 64 rows * 128 cols in our view\n                        CUTE_UNROLL\n                        for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) {\n                            // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)\n                            // NOTE Using `sQ_desc+((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4)` leads to IMA, doesn't know why\n                            UMMA::SmemDescriptor cur_sQ_desc = sQ_desc;\n    
                        cur_sQ_desc.lo += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4);\n                            // uint64_t cur_sQ_desc = sQ_desc;\n                            // cur_sQ_desc += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4);\n                            SM100_UTCCP_128dp256bit_2cta::copy(\n                                cur_sQ_desc,\n                                tmem_cols::Q + tile_idx*32 + subtile_idx*8\n                            );\n                        }\n                    }\n                    ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tQ_full, 1|2);\n                }\n            }\n\n            if (last_args.batch_idx != -1) {\n                perform_o_copy_out(last_args, false);\n            } else {\n                smem.bar_tQ_full.wait(args.outer_loop_phase);   // To prevent double arrive\n            }\n            last_args = args;\n        });\n        if (last_args.batch_idx != -1) {\n            cute::tma_store_wait<0>();\n            NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);\n            perform_o_copy_out(last_args, true);\n        }\n\n        if (warp_idx == 0) {\n            cute::TMEM::Allocator2Sm().free(0, 512);\n        }\n    } else if (warpgroup_idx == 1) {\n        // KV fetching threads for prefill, dequant threads for decoding\n        cutlass::arch::warpgroup_reg_dealloc<80>();\n        RingBufferState rs;\n\n        if constexpr (!IS_DECODE) {\n            const int warp_idx = cutlass::canonical_warp_idx();    // Using `warp_idx` without `__shfl_sync` is faster\n            if (elect_one_sync()) {\n                // KV fetching threads\n                run_outer_loop([&](const OuterloopArgs &args) {\n                    int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q;\n                    int64_t cache_hint = ku::create_simple_cache_policy<ku::CacheHint::EVICT_LAST>();\n\n                    static constexpr int NUM_ROWS_PER_THREAD = B_TOPK 
/ 4;\n\n                    CUTE_NO_UNROLL\n                    for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {\n                        auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();\n\n                        int cur_indices[NUM_ROWS_PER_THREAD];\n                        CUTE_UNROLL\n                        for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/8; local_row += 1) {\n                            int row = local_row*(4*8) + (warp_idx-4)*8;\n                            KU_LDG_256(\n                                gIndices + k*B_TOPK + row, \n                                cur_indices + local_row*8, \n                                \".nc\", \n                                \"no_allocate\", \n                                \"evict_first\", \n                                \"256B\"\n                            );\n                        }\n                        smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);\n\n                        CUTE_UNROLL\n                        for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/4; local_row += 1) {\n                            int row = (warp_idx-4)*8 + (local_row/2)*(4*8) + (local_row%2)*4;\n                            int4 indices = *(int4*)(cur_indices+local_row*4);\n                            static_assert(D_K == 512);\n                            CUTE_UNROLL\n                            for (int local_col = 0; local_col < (D_K/64)/2; ++local_col) {\n                                ku::tma_gather4_cta_group_2<true>(\n                                    &tma_params.tensor_map_kv,\n                                    smem.bar_KV_full[k_buf_idx],\n                                    smem.K[k_buf_idx].data() + row*64 + local_col*64*B_TOPK,\n                                    local_col*64 + cta_idx*(D_K/2),\n                                    indices,\n                                    cache_hint\n                                );\n                            }\n   
                     }\n                        rs.update();\n                    }\n                });\n            }\n                \n        } else {\n            // 8 threads per token\n            struct IsCTA0 {};\n            struct IsCTA1 {};\n\n            auto launch_dequant_wg = [&](auto cta_id_t) {\n                static constexpr bool IS_CTA1 = std::is_same<decltype(cta_id_t), IsCTA1>::value;\n                constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = (IS_CTA1 ? 256-64 : 256) / (GROUP_SIZE*8);\n                int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE;\n                Tensor nope0 = make_tensor(make_smem_ptr(smem.K[0].data()), ku::make_umma_canonical_k_major_layout<B_TOPK, D_K/2, 128>());\n                bf16* nope0_base = &nope0(group_idx, idx_in_group*8);\n                fp8_e4m3* raw_nope0_base = smem.K_raw[0].data() + group_idx*(D_K/2) + idx_in_group*8;\n                run_outer_loop([&](const OuterloopArgs &args) {\n                    CUTE_NO_UNROLL\n                    for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                        auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();\n                        auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get<NUM_RAW_K_BUFS>();\n                        auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();\n                        fp8_e4m3* raw_nope_base = raw_nope0_base + raw_k_buf_idx * (B_TOPK*(D_K/2));\n                        auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t {\n                            return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*(D_K/2) + local_col_idx*(GROUP_SIZE*8));\n                        };\n                        bf16* nope_base = nope0_base + k_buf_idx * (B_TOPK*(D_K/2));\n                        uint32_t cur_nope_base_uint_addr = 
cute::cast_smem_ptr_to_uint(nope_base);\n                        auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) {\n                            asm volatile (\"st.weak.shared::cta.b128 [%0], %1;\\n\" \n                                : \n                                : \"r\"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), \"q\"(data)   // 2 for sizeof(bf16)\n                            );  // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS\n                        };\n\n                        smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);\n                        smem.bar_raw_KV_full[raw_k_buf_idx].wait(raw_k_bar_phase);\n\n                        CUTE_UNROLL\n                        for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {\n                            int row_idx = local_row_idx*NUM_GROUPS + group_idx;\n                            bf16 scales[4];\n                            fp8_e8m0 scales_e8m0[4];\n                            *(uint32_t*)scales_e8m0 = *(uint32_t*)(smem.scales[index_buf_idx][row_idx]);\n                            *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));\n                            *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));\n\n                            uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);\n                            CUTE_UNROLL\n                            for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {\n                                ku::nve4m3x2 data_fp8[4];\n                                ku::nvbf16x2 data_bf16[4];\n                                *(uint64_t*)data_fp8 = cur_data_fp8x8;\n                                if (local_col_idx+1 < COLS_PER_GROUP)\n                                    cur_data_fp8x8 = 
get_raw_fp8(local_row_idx, local_col_idx+1);\n                                bf16 scale = scales[local_col_idx];\n                                CUTE_UNROLL\n                                for (int i = 0; i < 4; ++i) {\n                                    data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));\n                                }\n                                if (local_row_idx == 0 && local_col_idx == 0) {\n                                    smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);\n                                }\n                                st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);\n                            }\n                        }\n\n                        fence_view_async_shared();  // NOTE Should we use shared::cluster here?\n                        __syncwarp();\n                        smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();\n                        smem.bar_raw_KV_empty[raw_k_buf_idx].arrive();\n                        if (elect_one_sync()) {\n                            smem.bar_KV_full[k_buf_idx].arrive(0u);\n                        }\n                        rs.update();\n                    }\n                });\n            };\n            if (cta_idx == 0) {\n                launch_dequant_wg(IsCTA0{});\n            } else {\n                launch_dequant_wg(IsCTA1{});\n            }\n        }\n    } else if (warpgroup_idx == 2) {\n        cutlass::arch::warpgroup_reg_dealloc<80>();\n\n        RingBufferState rs;\n        if (warp_idx == 8 && cta_idx == 0 && elect_one_sync()) {\n            // UMMA thread\n            TiledMMA tiled_mma_P = TiledMMA_P{};\n            TiledMMA tiled_mma_O = TiledMMA_O{};\n            Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<H_Q/2>, Int<B_TOPK*2>>{});\n            Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<H_Q/2>, Int<D_V>>{});\n            Tensor tQ = 
tiled_mma_P.get_slice(_0{}).make_fragment_A(\n                partition_shape_A(tiled_mma_P, Shape<Int<H_Q/2>, Int<D_Q/2>>{})\n            );\n            tP.data().get() = tmem_cols::P;\n            tO.data().get() = tmem_cols::O;\n            tQ.data().get() = tmem_cols::Q;\n            \n            run_outer_loop([&](const OuterloopArgs &args) {\n                smem.bar_tQ_full.wait(args.outer_loop_phase);\n\n                // Issue P = Q K^T\n                auto issue_P = [&](int k, int rs_offset) {\n                    auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get<NUM_K_BUFS>();\n                    auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>();\n                    smem.bar_P_empty.wait(bar_phase^1);\n                    if constexpr (IS_PREFILL) {\n                        smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_K*sizeof(bf16));\n                    } else {\n                        // RoPE only\n                        smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16));\n                    }\n                    smem.bar_KV_full[k_buf_idx].wait(k_bar_phase);\n                    ku::tcgen05_after_thread_sync();\n                    Tensor sK = make_tensor(\n                        make_smem_ptr(smem.K[k_buf_idx].data()),\n                        ku::make_umma_canonical_k_major_layout<B_TOPK, D_K/2, 128>()\n                    );\n                    ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true);\n                    ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_QK_done, 1|2);\n                };\n\n                // Issue O += S V\n                auto issue_O = [&](int k, int rs_offset) {\n                    auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get<NUM_K_BUFS>();\n                    auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>();\n                    smem.bar_S_O_full.wait(bar_phase);\n                    if (k == args.start_block_idx) {\n                 
       smem.bar_tOut_empty.wait(args.outer_loop_phase^1);\n                    }\n                    ku::tcgen05_after_thread_sync();\n                    Tensor sS = make_tensor(\n                        make_smem_ptr(smem.S.data()),\n                        ku::make_umma_canonical_k_major_layout<H_Q/2, B_TOPK, 0>()\n                    );\n                    Tensor sV = make_tensor(\n                        make_smem_ptr(smem.K[k_buf_idx].data()),\n                        ku::make_umma_canonical_mn_major_layout<D_V/2, B_TOPK, 128>()\n                    );\n                    ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == args.start_block_idx);\n                    ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_SV_done, 1|2);\n                    ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_KV_empty[k_buf_idx], 1|2);\n                };\n\n                CUTE_NO_UNROLL\n                for (int k = args.start_block_idx; k < args.end_block_idx+1; ++k) {\n                    if (k < args.end_block_idx) {\n                        issue_P(k, 0);\n                    }\n                    if (k == args.end_block_idx-1) {\n                        ku::umma_arrive_2x1SM_noelect(smem.bar_tQ_empty);\n                    }\n\n                    if (k > args.start_block_idx) {\n                        issue_O(k-1, -1);\n                    }\n                    \n                    if (k != args.end_block_idx) {\n                        rs.update();\n                    }\n                }\n                ku::tcgen05_before_thread_sync();\n                ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tOut_full, 1|2);\n            });\n        } else if (warp_idx == 8 && cta_idx == 1 && elect_one_sync()) {\n            if constexpr (IS_DECODE) {\n                // KV RoPE fetching warp\n                run_outer_loop([&](const OuterloopArgs &args) {\n                    CUTE_NO_UNROLL\n                    for (int block_idx = args.start_block_idx; block_idx < 
args.end_block_idx; ++block_idx) {\n                        auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();\n                        auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();\n                        smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);\n                        smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);\n                        CUTE_UNROLL\n                        for (int row = 0; row < B_TOPK; row += 4) {\n                            int4 cur_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row);\n                            ku::tma_gather4_cta_group_2<true>(\n                                block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope,\n                                smem.bar_KV_full[k_buf_idx],\n                                smem.K[k_buf_idx].data() + (D_NOPE-D_K/2)*B_TOPK + row*D_ROPE,\n                                0,\n                                cur_indices,\n                                (int64_t)TMA::CacheHintSm90::EVICT_LAST\n                            );\n                        }\n                        smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();\n                        rs.update();\n                    }\n                });\n            }\n        } else if (warp_idx == 9) {\n            // KV validness loading warp (for prefill), Indices transformation warp (for decode, Responsible for generating: TMA coordinates, scale factors, and valid masks)\n            if constexpr (IS_PREFILL) {\n                if (lane_idx < B_TOPK/8) {\n                    run_outer_loop([&](const OuterloopArgs &args) {\n                        int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q;\n                        CUTE_NO_UNROLL\n                        for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {\n                            char k_validness_mask = 
load_indices_and_generate_mask(\n                                lane_idx,\n                                gIndices + k*B_TOPK,\n                                params.s_kv,\n                                k*B_TOPK,\n                                args.topk_length\n                            );\n                            \n                            auto [indices_buf_idx, indices_bar_phase] = rs.get<NUM_INDEX_BUFS>();\n                            smem.bar_valid_coord_scales_empty[indices_buf_idx].wait(indices_bar_phase^1);\n                            smem.is_k_valid[indices_buf_idx][lane_idx] = k_validness_mask;\n                            smem.bar_valid_coord_scales_full[indices_buf_idx].arrive();\n                            \n                            rs.update();\n                        }\n                    });\n                }\n            } else {\n                static_assert(B_TOPK == 64);\n                // Each thread is responsible for 2 tokens\n                static constexpr int tma_coords_step_per_token = 576/TMA_K_STRIDE_FOR_DECODING;\n                int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE_FOR_DECODING; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE_FOR_DECODING > 512\n                int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE_FOR_DECODING;\n                uint8_t* k_scales_ptr = (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE);\n                uint8_t* extra_k_scales_ptr = (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE);\n                \n                run_outer_loop([&](const OuterloopArgs &args) {\n                    int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*args.s_q_idx;\n                    int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*args.s_q_idx;\n          
          \n                    struct IsOrigBlock {};\n                    struct IsExtraBlock {};\n                    auto process_one_block = [&](int block_idx, auto is_extra_block_t) {\n                        auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();\n                        static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;\n                        int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size;\n                        int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block;\n                        [[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row;\n                        uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr;\n                        int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block;\n\n                        int abs_pos, my_indices[2];\n                        if (!IS_EXTRA_BLOCK) {\n                            abs_pos = block_idx*B_TOPK + lane_idx*2;\n                            *(int2*)my_indices = __ldg((int2*)(indices + abs_pos));\n                        } else {\n                            abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2;\n                            *(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos));\n                        }\n                        smem.bar_valid_coord_scales_empty[index_buf_idx].wait(index_bar_phase^1);\n\n                        int tma_coords[2];\n                        fp8_e8m0 scales[2*(NUM_SCALES_EACH_TOKEN/2)];\n                        char valid_mask = 0;\n                        CUTE_UNROLL\n                        for (int i = 0; i < 2; ++i) {\n                            int block_idx, idx_in_block;\n                            block_idx = (unsigned 
int)my_indices[i] / cur_block_size;\n                            idx_in_block = (unsigned int)my_indices[i] % cur_block_size;\n                            bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length));\n                            valid_mask |= is_token_valid << i;\n                            tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because its top-k position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying in NaN.\n\n                            int64_t offset = block_idx*cur_k_block_stride + (idx_in_block*8 + (cta_idx == 1 ? 4 : 0)); // Each token has 7 scale factors plus 1 byte of padding\n                            uint32_t scalesx4 = is_token_valid ? __ldg((uint32_t*)(cur_k_scales_ptr + offset)) : 0;\n                            *(uint32_t*)(scales+i*(NUM_SCALES_EACH_TOKEN/2)) = scalesx4;\n                        }\n                        valid_mask <<= lane_idx%4*2;\n                        valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1);\n                        valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2);\n                        *(uint64_t*)(smem.scales[index_buf_idx] + lane_idx*2) = *(uint64_t*)scales;\n                        *(int2*)(smem.tma_coord[index_buf_idx] + lane_idx*2) = *(int2*)tma_coords;\n                        if (lane_idx%4 == 0)\n                            smem.is_k_valid[index_buf_idx][lane_idx/4] = valid_mask;\n                        \n                        smem.bar_valid_coord_scales_full[index_buf_idx].arrive();\n                        rs.update();\n                    };\n\n                    CUTE_NO_UNROLL\n                    for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {\n                        process_one_block(block_idx, 
IsOrigBlock{});\n                    }\n\n                    CUTE_NO_UNROLL\n                    for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) {\n                        process_one_block(block_idx, IsExtraBlock{});\n                    }\n                });\n            }\n        } else if (warp_idx >= 10 && elect_one_sync()) {\n            if constexpr (IS_PREFILL) {\n                if (warp_idx == 10) {\n                    // CLC Producer thread\n                    run_outer_loop([&](const OuterloopArgs &args) {\n                        if (cta_idx == 0) {\n                            smem.bar_clc_empty.wait(args.outer_loop_phase^1);\n                            ku::issue_clc_query_multicast_cluster_all(smem.bar_clc_full, smem.clc_response_obj);\n                        }\n                        smem.bar_clc_full.arrive_and_expect_tx(sizeof(smem.clc_response_obj));\n                    });\n                }\n            } else {\n                // Raw KV NoPE Producer thread\n                run_outer_loop([&](const OuterloopArgs &args) {\n                    CUTE_NO_UNROLL\n                    for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                        auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get<NUM_RAW_K_BUFS>();\n                        auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();\n                        smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);\n                        smem.bar_raw_KV_empty[raw_k_buf_idx].wait(raw_k_bar_phase^1);\n\n                        int4 nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + (warp_idx == 10 ? 0 : 4));\n                        CUTE_UNROLL\n                        for (int row = (warp_idx == 10 ? 
0 : 4); row < B_TOPK; row += 8) {\n                            int4 cur_indices = nxt_indices;\n                            if (row+8 < B_TOPK)\n                                nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row + 8);\n                            ku::tma_gather4(\n                                block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope,\n                                smem.bar_raw_KV_full[raw_k_buf_idx],\n                                smem.K_raw[raw_k_buf_idx].data() + row*(D_K/2),\n                                cta_idx*(D_K/2),\n                                cur_indices,\n                                (int64_t)TMA::CacheHintSm90::EVICT_LAST\n                            );\n                        }\n                        if (warp_idx == 10) {\n                            smem.bar_raw_KV_full[raw_k_buf_idx].arrive_and_expect_tx(B_TOPK*(D_K/2)*sizeof(fp8_e4m3));\n                        }\n                        smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();\n                        rs.update();\n                    }\n                });\n            }\n        }\n    } else {\n        // Scale & Exp threads\n        cutlass::arch::warpgroup_reg_alloc<176>();\n\n        int local_warp_idx = warp_idx - 12;\n        bf16* sS_base = smem.S.data() + (local_warp_idx >= 2 ? 
(H_Q/2)*(B_TOPK/2) : 0) + (idx_in_warpgroup%64)*8;\n\n        RingBufferState rs;\n        run_outer_loop([&](const OuterloopArgs &args) {\n            // For the definition and consistency of `mi`, `li`, and `real_mi`, please refer to head64 prefill\n            float mi = MAX_INIT_VAL;\n            float li = 0.0f;\n            float real_mi = -CUDART_INF_F;\n            static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2;\n\n            CUTE_NO_UNROLL\n            for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {\n                auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();\n                auto [indices_buf_idx, indices_bar_phase] = rs.get<NUM_INDEX_BUFS>();\n                auto [_, bar_phase] = rs.get<1>();\n                // NOTE We don't need to sync for Prefill mode, since we have two synchronizations inside the loop body (one for p_exchange_buf sync, another one for rowwise_max_buf sync). The latter one guarantees the emptiness of p_exchange_buf and the former one guarantees the emptiness of rowwise_max_buf\n                smem.bar_valid_coord_scales_full[indices_buf_idx].wait(indices_bar_phase);\n\n                // Get P from TMEM\n                float p[NUM_ELEMS_PER_THREAD];\n                smem.bar_QK_done.wait(bar_phase);\n                ku::tcgen05_after_thread_sync();\n                retrieve_mask_and_reduce_p<\n                    NUM_ELEMS_PER_THREAD,\n                    tmem_cols::P,\n                    barrier_ids::WG2_WARP02_SYNC,\n                    barrier_ids::WG2_WARP13_SYNC,\n                    false\n                >(\n                    smem.is_k_valid[indices_buf_idx],\n                    local_warp_idx,\n                    lane_idx,\n                    [&]() {smem.bar_P_empty.arrive(0u);},\n                    smem.P_exchange,\n                    p\n                );\n\n                // Get rowwise max of P\n                float cur_pi_max = get_max<NUM_ELEMS_PER_THREAD>(p);\n              
  cur_pi_max *= params.sm_scale_div_log2;\n\n                smem.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;\n                NamedBarrier::arrive_and_wait(64, barrier_ids::WG2_WARP02_SYNC + (local_warp_idx&1));\n                cur_pi_max = max(cur_pi_max, smem.rowwise_max_buf[idx_in_warpgroup^64]);\n                real_mi = max(real_mi, cur_pi_max);\n                bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);\n\n\n                // Calc scale factor, and scale li\n                float new_max, scale_for_old;\n                if (!should_scale_o) {\n                    // Don't scale O\n                    scale_for_old = 1.0f;\n                    new_max = mi;\n                } else {\n                    new_max = max(cur_pi_max, mi);\n                    scale_for_old = exp2f(mi - new_max);\n                }\n                mi = new_max;   // mi is still identical within each row\n\n                // Calculate S\n                nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2];\n                float cur_sum = get_s_from_p<NUM_ELEMS_PER_THREAD>(s, p, params.sm_scale_div_log2, new_max);\n                li = fmaf(li, scale_for_old, cur_sum);\n\n                // Store S\n                smem.bar_SV_done.wait(bar_phase^1);\n                CUTE_UNROLL\n                for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; ++i) {\n                    ku::st_shared(sS_base + i*8*(H_Q/2), *(__int128_t*)(s + i*4));\n                }\n\n                // Rescale O\n                if (k > 0 && should_scale_o) {\n                    ku::tcgen05_after_thread_sync();\n                    rescale_O<D_V, 32, tmem_cols::O>(scale_for_old);\n                    ku::tcgen05_before_thread_sync();\n                }\n\n                fence_view_async_shared();\n                smem.bar_S_O_full.arrive(0u);\n                smem.bar_valid_coord_scales_empty[indices_buf_idx].arrive();\n\n                rs.update();\n            }\n\n            if (real_mi 
== -CUDART_INF_F) {\n                // real_mi == -CUDART_INF_F <=> No valid TopK indices\n                // We set li to 0 to fit the definition that li := sum_i exp(x[i] - mi)\n                li = 0.0f;\n                mi = -CUDART_INF_F;\n            }\n\n            // Reduce li\n            smem.bar_li_empty.wait(args.outer_loop_phase^1);\n            smem.rowwise_li_buf[idx_in_warpgroup^64] = li;\n            NamedBarrier::arrive_and_wait(128, barrier_ids::WG2_SYNC);\n            li += smem.rowwise_li_buf[idx_in_warpgroup];\n\n            if (idx_in_warpgroup < H_Q/2) {\n                // Calculate output_scale and save\n                int head_idx = cta_idx*(H_Q/2) + idx_in_warpgroup;\n                float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + head_idx);\n                float output_scale;\n                if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {\n                    output_scale = __fdividef(1.0f, li + exp2f(fmaf(attn_sink, CUDART_L2E_F, -mi)));\n                } else {\n                    output_scale = __fdividef(1.0f, li);\n                }\n                smem.rowwise_li_buf[idx_in_warpgroup] = li == 0.0f ? 0.0f : output_scale;\n                smem.bar_li_full.arrive();\n\n                float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));\n                cur_lse = cur_lse == -CUDART_INF_F ? 
+CUDART_INF_F : cur_lse;\n                if constexpr (IS_PREFILL) {\n                    int global_index = args.s_q_idx*params.h_q + head_idx;\n                    params.max_logits[global_index] = real_mi*CUDART_LN2_F;\n                    params.lse[global_index] = cur_lse;\n                } else {\n                    if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {\n                        params.lse[args.batch_idx*params.stride_lse_b + args.s_q_idx*params.stride_lse_s_q + head_idx] = cur_lse;\n                    } else {\n                        float cur_lse_2base = log2f(li) + mi;\n                        params.lse_accum[args.n_split_idx*params.stride_lse_accum_split + args.s_q_idx*params.stride_lse_accum_s_q + head_idx] = cur_lse_2base;\n                    }\n                }\n\n            }\n        });\n    }\n\n    ku::barrier_cluster_arrive_relaxed();\n    ku::barrier_cluster_wait_acquire();\n\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm100\");\n    }\n#endif\n}\n\n// We have two launchers with different kernel names to distinguish prefill and decode\n\ntemplate<typename Kernel>\nstatic __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)\nsparse_attn_fwd_for_small_topk_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) {\n    Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);\n}\n\ntemplate<typename Kernel>\nstatic __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)\nflash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) {\n    Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);\n}\n\ntemplate<FwdMode FWD_MODE, int D_QK>\nvoid KernelTemplate<FWD_MODE, D_QK>::run(const ArgT& params) {\n    static_assert(D_QK == 576 || D_QK == 512);\n\n    KU_ASSERT(params.h_kv == 
1);\n    KU_ASSERT(params.topk % B_TOPK == 0);   // To save some boundary checks\n    KU_ASSERT(params.h_q == H_Q);  // To save some calculation\n    KU_ASSERT(params.d_qk == D_QK);\n\n    static_assert(D_Q == 512);\n    CUtensorMap tensor_map_q;\n    if constexpr (IS_DECODE) {\n        KU_ASSERT(params.stride_q_b % params.stride_q_s_q == 0, \"In decode mode for MODEL1 sparse fp8 decoding on sm100f, q.stride(0) (on the batch dimension) must be divisible by q.stride(1) (on the sequence dimension).\");\n        tensor_map_q = ku::make_tensor_map(\n            {64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.b * (params.stride_q_b / params.stride_q_s_q)},\n            ku::make_stride_helper<int>({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)),\n            {64, H_Q/2, 2, (D_Q/64)/2, 1},\n            params.q,\n            CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            CU_TENSOR_MAP_SWIZZLE_128B,\n            CU_TENSOR_MAP_L2_PROMOTION_L2_256B\n        );\n    } else {\n        tensor_map_q = ku::make_tensor_map(\n            {64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.s_q},\n            ku::make_stride_helper<int>({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)),\n            {64, H_Q/2, 2, (D_Q/64)/2, 1},\n            params.q,\n            CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            CU_TENSOR_MAP_SWIZZLE_128B,\n            CU_TENSOR_MAP_L2_PROMOTION_L2_256B\n        );  // We use this layout to group Q[0:64] and Q[256:256+64] together, for UTCCP for dual gemm\n    }\n\n    CUtensorMap tensor_map_kv;\n    CUtensorMap tensor_map_kv_nope, tensor_map_kv_rope, tensor_map_extra_kv_nope = {}, tensor_map_extra_kv_rope = {};\n    if constexpr (IS_DECODE) {\n        auto get_kv_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t stride_kv_block, int64_t stride_kv_row) -> std::pair<CUtensorMap, CUtensorMap> {\n            KU_ASSERT((int64_t)k_ptr % 16 == 0, \"The base address of %sk_ptr (%p) must be 
16B aligned for sparse fp8 attention on sm100f\", is_extra?\"extra_\":\"\", k_ptr);\n            KU_ASSERT(stride_kv_block % TMA_K_STRIDE_FOR_DECODING == 0, \"%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary\", is_extra?\"extra_\":\"\", stride_kv_block, TMA_K_STRIDE_FOR_DECODING);\n            CUtensorMap tensor_map_kv_nope = ku::make_tensor_map(\n                {D_NOPE + D_ROPE*2, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)},\n                {TMA_K_STRIDE_FOR_DECODING},\n                {D_K/2, 1},\n                k_ptr,\n                CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8,\n                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,\n                CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B\n            );  // NOTE: Here we use `D_NOPE+D_ROPE*2` as the box shape instead of D_NOPE because it's actually faster. I think that's because, if we use `D_NOPE+D_ROPE*2`, we can prefetch part of the RoPE part of the selected tokens.\n            CUtensorMap tensor_map_kv_rope = ku::make_tensor_map(\n                {D_ROPE, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)},\n                {TMA_K_STRIDE_FOR_DECODING},\n                {64, 1},\n                (uint8_t*)k_ptr + D_NOPE,\n                CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,\n                CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B\n            );\n            return {tensor_map_kv_nope, tensor_map_kv_rope};\n        };\n        std::tie(tensor_map_kv_nope, tensor_map_kv_rope) = get_kv_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block, params.stride_kv_row);\n        if (params.extra_topk > 0)\n            std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_kv_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block, 
params.stride_extra_kv_row);\n    } else {\n        tensor_map_kv = ku::make_tensor_map(\n            {D_QK, (unsigned long)params.s_kv}, \n            {(unsigned long)params.stride_kv_s_kv*sizeof(bf16)},\n            {64, 1},\n            params.kv,\n            CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            CU_TENSOR_MAP_SWIZZLE_128B,\n            CU_TENSOR_MAP_L2_PROMOTION_L2_256B\n        );\n    }\n\n    CUtensorMap tensor_map_o;\n    if constexpr (IS_DECODE) {\n        tensor_map_o = ku::make_tensor_map(\n            {64, H_Q, D_V/64, (unsigned long)params.s_q, (unsigned long)params.b},\n            ku::make_stride_helper<int>({params.stride_o_h_q, 64, params.stride_o_s_q, params.stride_o_b}, sizeof(bf16)),\n            {64, H_Q/2, D_V/64, 1, 1},\n            params.out,\n            CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            CU_TENSOR_MAP_SWIZZLE_128B,\n            CU_TENSOR_MAP_L2_PROMOTION_L2_256B\n        );\n    } else {\n        tensor_map_o = ku::make_tensor_map(\n            {64, H_Q, D_V/64, (unsigned long)params.s_q, 1ul},\n            ku::make_stride_helper<int>({D_V, 64, H_Q*D_V, H_Q*D_V}, sizeof(bf16)),\n            {64, H_Q/2, D_V/64, 1, 1},\n            params.out,\n            CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            CU_TENSOR_MAP_SWIZZLE_128B,\n            CU_TENSOR_MAP_L2_PROMOTION_L2_256B\n        );\n    }\n\n\n    CUtensorMap tensor_map_o_accum = {};\n    if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {\n        tensor_map_o_accum = ku::make_tensor_map(\n            {32, H_Q, D_V/32, (unsigned long)params.s_q, (unsigned long)params.num_sm_parts + params.b},\n            ku::make_stride_helper<int>({params.stride_o_accum_h_q, 32, params.stride_o_accum_s_q, params.stride_o_accum_split}, sizeof(float)),\n            {32, H_Q/2, B_EPI_SPLITKV/32, 1, 1},\n            params.o_accum,\n            CU_TENSOR_MAP_DATA_TYPE_FLOAT32,\n            CU_TENSOR_MAP_SWIZZLE_128B,\n            CU_TENSOR_MAP_L2_PROMOTION_L2_256B\n        
);\n    }\n\n    TmaParams tma_params;\n    if constexpr (IS_DECODE) {\n        tma_params = {\n            tensor_map_q,\n            tensor_map_o,\n            tensor_map_o_accum,\n            tensor_map_kv_nope,\n            tensor_map_kv_rope,\n            tensor_map_extra_kv_nope,\n            tensor_map_extra_kv_rope\n        };\n    } else {\n        tma_params = {\n            tensor_map_q,\n            tensor_map_kv,\n            tensor_map_o\n        };\n    }\n    \n    auto kernel = IS_PREFILL ? &sparse_attn_fwd_for_small_topk_kernel<KernelTemplate<FWD_MODE, D_QK>> : &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<FWD_MODE, D_QK>>;\n    constexpr size_t smem_size = sizeof(SharedMemoryPlan);\n    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n\n    dim3 grid_shape;\n    if constexpr (IS_DECODE) {\n        grid_shape = dim3(2*params.s_q, FWD_MODE == FwdMode::DecodeWithSplitKV ? params.num_sm_parts : params.b, 1);\n    } else {\n        grid_shape = dim3(2*params.s_q, 1, 1);\n    }\n\n    cutlass::ClusterLaunchParams launch_params = {\n        grid_shape,\n        dim3(NUM_THREADS, 1, 1),\n        dim3(2, 1, 1),\n        smem_size,\n        params.stream\n    };\n    KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster(\n        launch_params, (void*)kernel, params, tma_params\n    ));\n}\n\ntemplate<FwdMode FWD_MODE, int D_QK>\nvoid run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT<FWD_MODE>& params) {\n    using Kernel = KernelTemplate<FWD_MODE, D_QK>;\n    Kernel::run(params);\n}\n\n}\n"
  },
  {
    "path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\ntemplate<SparseAttnFwdMode FWD_MODE, int D_QK>\nvoid run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT<FWD_MODE>& params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/dense/config.h",
    "content": "#pragma once\n\nnamespace Config {\n\nstatic constexpr int BLOCK_SIZE_M = 64;\nstatic constexpr int PAGE_BLOCK_SIZE = 64;\n\nstatic constexpr int HEAD_DIM_K = 576;\nstatic constexpr int HEAD_DIM_V = 512;\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/dense/instantiations/bf16.cu",
    "content": "#include \"../splitkv_mla.cuh\"\n#include \"../splitkv_mla.h\"\n\nnamespace sm90 {\n\ntemplate void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams &params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/dense/instantiations/fp16.cu",
    "content": "#include \"../splitkv_mla.cuh\"\n#include \"../splitkv_mla.h\"\n\nnamespace sm90 {\n\n#ifndef FLASH_MLA_DISABLE_FP16\ntemplate void run_flash_splitkv_mla_kernel<cutlass::half_t>(DenseAttnDecodeParams &params);\n#endif\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/dense/splitkv_mla.cuh",
    "content": "#include <cutlass/cutlass.h>\n\n#include \"utils.h\"\n\n#include \"params.h\"\n#include \"config.h\"\n#include \"traits.h\"\n\nusing namespace cute;\nusing cutlass::arch::NamedBarrier;\n\nnamespace sm90 {\n\n// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking\n// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2)\n// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM\nstatic constexpr float MAX_INIT_VAL_SM = -1e30f;\nstatic constexpr float MAX_INIT_VAL = -1e33f;\n\n__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {\n    // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx\n    // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a\n    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);\n    return row_idx;\n}\n\n// Launch TMA copy for a range of KV tile\n// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64\ntemplate<\n    int START_HEAD_DIM_TILE_IDX,\n    int END_HEAD_DIM_TILE_IDX,\n    typename TMA_K_OneTile,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1\n>\n__forceinline__ __device__ void launch_kv_tiles_copy_tma(\n    Tensor<Engine0, Layout0> const &gKV,\t// (PAGE_BLOCK_SIZE, HEAD_DIM_K)\n    Tensor<Engine1, Layout1> &sKV,\t// (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled\n    TMA_K_OneTile &tma_K,\n    TMABarrier* barriers_K,\n    int idx_in_warpgroup\n) {\n    if (idx_in_warpgroup == 0) {\n        auto thr_tma = tma_K.get_slice(_0{});\n        Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});\n        Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});\n        
cute::copy(tma_K.with(reinterpret_cast<typename TMABarrier::ValueType &>(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV);\n        if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {\n            launch_kv_tiles_copy_tma<START_HEAD_DIM_TILE_IDX+1, END_HEAD_DIM_TILE_IDX>(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup);\n        }\n    }\n}\n\n// Prefetch some KV tiles\n// Currently this is not used because it leads to performance degradation\ntemplate<\n    int START_HEAD_DIM_TILE_IDX,\n    int END_HEAD_DIM_TILE_IDX,\n    typename TMA_K_OneTile,\n    typename Engine0, typename Layout0\n>\n__forceinline__ __device__ void prefetch_kv_tiles(\n    Tensor<Engine0, Layout0> const &gKV,\t// (PAGE_BLOCK_SIZE, HEAD_DIM_K)\n    TMA_K_OneTile &tma_K,\n    int idx_in_warpgroup\n) {\n    if (idx_in_warpgroup == 0) {\n        auto thr_tma = tma_K.get_slice(_0{});\n        Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});\n        cute::prefetch(tma_K, cur_gKV);\n        if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {\n            prefetch_kv_tiles<START_HEAD_DIM_TILE_IDX+1, END_HEAD_DIM_TILE_IDX>(gKV, tma_K, idx_in_warpgroup);\n        }\n    }\n}\n\n// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h\n// * Copyright (c) 2024, Tri Dao.\ntemplate <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\n__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {\n    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;\n    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const\n    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n 
   warpgroup_fence_operand(tCrC);\n    if constexpr (arrive) {\n        warpgroup_arrive();\n    }\n    if constexpr (zero_init) {\n        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;\n        // Unroll the K mode manually to set scale D to 1\n        CUTLASS_PRAGMA_UNROLL\n        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n            tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n        }\n    } else {\n        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);\n        // Unroll the K mode manually to set scale D to 1\n        CUTLASS_PRAGMA_UNROLL\n        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n            tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n        }\n    }\n    if constexpr (commit) {\n        warpgroup_commit_batch();\n    }\n    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }\n    warpgroup_fence_operand(tCrC);\n    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n}\n\n\n// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64)\n// The Q-tile should be in shared memory\ntemplate<\n    typename TiledMMA,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2\n> \n__forceinline__ __device__ void qkt_gemm_one_tile_sQ(\n    TiledMMA &tiled_mma,\n    Tensor<Engine0, Layout0> const &thr_mma_sQ_tile,\t// (MMA, 1, 4)\n    Tensor<Engine1, Layout1> const &thr_mma_sKV_tile,\t// (MMA, 1, 4)\n    Tensor<Engine2, Layout2> &rP,\t// ((2, 2, 8), 1, 1)\n    TMABarrier* barrier,\n    bool &cur_phase,\n    int idx_in_warpgroup\n) {\n    if (idx_in_warpgroup == 0) {\n        barrier->arrive_and_expect_tx(64*64*2);\n    }\n    barrier->wait(cur_phase ? 
1 : 0);\n\n    warpgroup_fence_operand(rP);\n    warpgroup_arrive();\n    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);\n    tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);\n    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);\n    cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);\n    warpgroup_commit_batch();\n    warpgroup_fence_operand(rP);\n}\n\ntemplate<\n    typename TiledMMA,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2\n> \n__forceinline__ __device__ void  qkt_gemm_one_tile_rQ(\n    TiledMMA &tiled_mma,\n    Tensor<Engine0, Layout0> const &thr_mma_rQ_tile,\t// (MMA, 1, 4)\n    Tensor<Engine1, Layout1> const &thr_mma_sKV_tile,\t// (MMA, 1, 4)\n    Tensor<Engine2, Layout2> &rP,\t// ((2, 2, 8), 1, 1)\n    TMABarrier* barrier,\n    bool &cur_phase,\n    int idx_in_warpgroup\n) {\n    if (idx_in_warpgroup == 0) {\n        barrier->arrive_and_expect_tx(64*64*2);\n    }\n    barrier->wait(cur_phase ? 
1 : 0);\n\n    warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));\n    warpgroup_fence_operand(rP);\n    warpgroup_arrive();\n    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);\n    tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);\n    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);\n    cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);\n    warpgroup_commit_batch();\n    warpgroup_fence_operand(rP);\n    warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));\n}\n\n// Pipelined TMA wait and Q K^T gemm\n// In order to overlap memory copy (G->S copy for K) and computation, we divide Q and K into tiles of shape (BLOCK_SIZE_M, 64) and (PAGE_BLOCK_SIZE, 64), respectively, and then do the computation as follows:\n// - Wait for the 0-th tile to be ready using `barrier.wait()`\n// - Compute Q K^T for the 0-th tile\n// - Wait for the 1-st tile to be ready\n// - Compute Q K^T for the 1-st tile\n// ...\n// This gives later tiles more time to be ready, and thus overlaps the memory copy with computation\ntemplate<\n    typename T, // Traits\n    int PHASE_IDX,\t// See comments in the code\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2,\n    typename Engine3, typename Layout3\n> \n__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm(\n    Tensor<Engine0, Layout0> &sQ,\t// (BLOCK_SIZE_M, HEAD_DIM_K)\n    Tensor<Engine1, Layout1> &sKV,\t// (PAGE_BLOCK_SIZE, HEAD_DIM_K)\n    Tensor<Engine2, Layout2> &rP,\t// ((2, 2, 8), 1, 1)\n    Tensor<Engine3, Layout3> &rQ8,\t// The 8-th tile of Q. 
We store it separately to leave some room for storing sP1\n    TMABarrier* barriers,\n    bool &cur_phase,\n    int idx_in_warpgroup\n) {\n    Tensor sQ_tiled = flat_divide(sQ, Shape<Int<T::BLOCK_SIZE_M>, _64>{})(_, _, _0{}, _);\t// (BLOCK_SIZE_M, 64, 9)\n    Tensor sKV_tiled = flat_divide(sKV, Shape<Int<T::PAGE_BLOCK_SIZE>, _64>{})(_, _, _0{}, _);\t// (PAGE_BLOCK_SIZE, 64, 9)\n    TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){};\n    ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup);\n    Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled);\t// (MMA, 1, 4, 9)\n    Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled);\t// (MMA, 1, 4, 9)\n    TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){};\n\n    #define QKT_GEMM_ONE_TILE(TILE_IDX) \\\n        if constexpr(TILE_IDX != 8) { \\\n            qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int<TILE_IDX>{}), thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \\\n        } else { \\\n            qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \\\n        }\n\n    if constexpr (PHASE_IDX == 0) {\n        // In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles\n        tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;\n        tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;\n        QKT_GEMM_ONE_TILE(0);\n        QKT_GEMM_ONE_TILE(1);\n        QKT_GEMM_ONE_TILE(2);\n        QKT_GEMM_ONE_TILE(3);\n    } else if constexpr (PHASE_IDX == 1) {\n        // In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles\n        tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;\n        tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;\n        QKT_GEMM_ONE_TILE(4);\n        QKT_GEMM_ONE_TILE(5);\n        QKT_GEMM_ONE_TILE(6);\n        QKT_GEMM_ONE_TILE(7);\n        QKT_GEMM_ONE_TILE(8);\n        
QKT_GEMM_ONE_TILE(0);\n        QKT_GEMM_ONE_TILE(1);\n        QKT_GEMM_ONE_TILE(2);\n        QKT_GEMM_ONE_TILE(3);\n        cur_phase ^= 1;\n    } else {\n        // In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles\n        static_assert(PHASE_IDX == 2);\n        tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One;\n        tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;\n        QKT_GEMM_ONE_TILE(4);\n        QKT_GEMM_ONE_TILE(5);\n        QKT_GEMM_ONE_TILE(6);\n        QKT_GEMM_ONE_TILE(7);\n        QKT_GEMM_ONE_TILE(8);\n        cur_phase ^= 1;\n    }\n}\n\n\ntemplate<\n    typename T,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2\n> \n__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline(\n    Tensor<Engine0, Layout0> &sQ,\t// (BLOCK_SIZE_M, HEAD_DIM_K)\n    Tensor<Engine1, Layout1> &sKV,\t// (BLOCK_SIZE_M, HEAD_DIM_K)\n    Tensor<Engine2, Layout2> &rP,\t// ((2, 2, 8), 1, 1)\n    int idx_in_warpgroup\n) {\n    TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){};\n    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);\n    Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ);\t// (MMA, 1, 576/16=36)\n    Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV);\t// (MMA, 1, 576/16=36)\n    gemm<true, -1>(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP);\n}\n\n\n// Compute O += PV, where P resides in register\ntemplate<\n    typename T,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2\n> \n__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP(\n    Tensor<Engine0, Layout0> &rP,\t// ((2, 2, 8), 1, 1), fragment A layout\n    Tensor<Engine1, Layout1> &sKV_half,\t// (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)\n    Tensor<Engine2, Layout2> &rO,\t// ((2, 2, 32), 1, 1)\n    int idx_in_warpgroup\n) {\n    TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){};\n    ThrMMA thr_mma = 
tiled_mma.get_slice(idx_in_warpgroup);\n    Tensor rP_retiled = make_tensor(rP.data(), Layout<\n        Shape<Shape<_2, _2, _2>, _1, _4>,\n        Stride<Stride<_1, _2, _4>, _0, _8>\n    >{});\n    Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half);\t// (MMA, 1, 64/16=4)\n    gemm<false, -1>(tiled_mma, rP_retiled, thr_mma_sKV_half, rO);\n}\n\n\n// Compute O += PV, where P resides in shared memory\ntemplate<\n    typename T,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2\n> \n__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP(\n    Tensor<Engine0, Layout0> &sP,\n    Tensor<Engine1, Layout1> &sKV_half,\t// (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)\n    Tensor<Engine2, Layout2> &rO,\t// ((2, 2, 32), 1, 1)\n    int idx_in_warpgroup\n) {\n    TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){};\n    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);\n    Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP);\n    Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half);\t// (MMA, 1, 64/16=4)\n    gemm<false, -1>(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO);\n}\n\n\ntemplate<\n    typename T,\n    bool DO_OOB_FILLING,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2,\n    typename Engine3, typename Layout3,\n    typename Engine4, typename Layout4\n>\n__forceinline__ __device__ void wg0_bunch_0(\n    Tensor<Engine0, Layout0> &rPb,\t// ((2, 2, 8), 1, 1)\n    Tensor<Engine1, Layout1> &rP0,\t// ((2, 2, 8), 1, 1)\n    Tensor<Engine2, Layout2> &rO0,\t// ((2, 2, 32), 1, 1)\n    Tensor<Engine3, Layout3> &sScale0,\t// (BLOCK_SIZE_M)\n    Tensor<Engine4, Layout4> &sM,\t// (BLOCK_SIZE_M)\n    float rL[2],\n    int rRightBorderForQSeq[2],\n    float scale_softmax_log2,\n    int start_token_idx,\n    int idx_in_warpgroup\n) {\n    // This piece of code is tightly coupled [Accumulate's 
layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png)\n    CUTLASS_PRAGMA_UNROLL\n    for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {\n        int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);\n\n        // Mask, and get row-wise max\n        float cur_max = MAX_INIT_VAL;\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {\n            if constexpr (DO_OOB_FILLING) {\n                int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;\n                rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL;\n                rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL;\n            }\n            cur_max = max(cur_max, max(rP0(i), rP0(i+1)));\n        }\n        cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));\n        cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));\n        \n        // Update sM and sL\n        cur_max *= scale_softmax_log2;\n        float new_max = max(sM(row_idx), cur_max);\n        float scale_for_old = exp2f(sM(row_idx) - new_max);\n        __syncwarp();   // Make sure all reads have finished before updating sM\n        if (idx_in_warpgroup%4 == 0) {\n            sScale0(row_idx) = scale_for_old;\n            sM(row_idx) = new_max;\n        }\n        \n        // Scale-O\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) {\n            rO0(i) *= scale_for_old;\n            rO0(i+1) *= scale_for_old;\n        }\n\n        // Scale, exp, and get row-wise expsum\n        float cur_sum = 0;\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = local_row_idx ? 
2 : 0; i < size(rP0); i += 4) {\n            rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max);\n            rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max);\n            rPb(i) = (typename T::InputT)rP0(i);\n            rPb(i+1) = (typename T::InputT)rP0(i+1);\n            cur_sum += rP0(i) + rP0(i+1);\n        }\n        rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;\n    }\n}\n\n\ntemplate<\n    typename T,\n    bool IS_BLK0_LAST,\n    bool IS_BLK1_LAST,\n    bool IS_BLK2_LAST,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2,\n    typename Engine3, typename Layout3,\n    typename Engine4, typename Layout4,\n    typename Engine5, typename Layout5\n>\n__forceinline__ __device__ void wg1_bunch_0(\n    Tensor<Engine0, Layout0> &rP1b,\t// ((2, 2, 8), 1, 1)\n    Tensor<Engine1, Layout1> &sScale1,\t// (BLOCK_SIZE_M)\n    Tensor<Engine2, Layout2> &rO1,\t// ((2, 2, 32), 1, 1)\n    Tensor<Engine3, Layout3> &sM,\t// (BLOCK_SIZE_M)\n    float rL[2],\n    int rRightBorderForQSeq[2],\n    Tensor<Engine4, Layout4> const &sScale0,\t// (BLOCK_SIZE_M)\n    Tensor<Engine5, Layout5> &rP1,\t// ((2, 2, 8), 1, 1)\n    float scale_softmax_log2,\n    int start_token_idx,\n    int idx_in_warpgroup\n) {\n    CUTLASS_PRAGMA_UNROLL\n    for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {\n        int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);\n\n        // Mask, and get row-wise max\n        float cur_max = MAX_INIT_VAL;\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = local_row_idx ? 
2 : 0; i < size(rP1); i += 4) {\n            if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) {\n                // Need to apply the mask when either this block is the last one, or\n                // the next block is the last one (because of the causal mask)\n                int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;\n                rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL;\n                rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP1(i+1) : MAX_INIT_VAL;\n            } else if constexpr (IS_BLK0_LAST) {\n                rP1(i) = rP1(i+1) = MAX_INIT_VAL;\n            }\n            cur_max = max(cur_max, max(rP1(i), rP1(i+1)));\n        }\n        cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));\n        cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));\n        cur_max *= scale_softmax_log2;\n\n        float old_max = sM(row_idx);\n        float new_max = max(old_max, cur_max);\n        float scale_for_old = exp2f(old_max - new_max);\n        __syncwarp();\n        if (idx_in_warpgroup%4 == 0) {\n            sM(row_idx) = new_max;\n            sScale1(row_idx) = scale_for_old;\n        }\n\n        // Scale, exp, and get row-wise expsum\n        float cur_sum = 0;\n        if constexpr (!IS_BLK0_LAST) {\n            CUTLASS_PRAGMA_UNROLL\n            for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) {\n                rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max);\n                rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max);\n                rP1b(i) = (typename T::InputT)rP1(i);\n                rP1b(i+1) = (typename T::InputT)rP1(i+1);\n                cur_sum += rP1(i) + rP1(i+1);\n            }\n        }\n\n        // Scale O\n        float cur_scale_for_o1 = scale_for_old * sScale0(row_idx);\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = local_row_idx ? 
2 : 0; i < size(rO1); i += 4) {\n            rO1(i) *= cur_scale_for_o1;\n            rO1(i+1) *= cur_scale_for_o1;\n        }\n\n        // Update rL\n        rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum;\n    }\n}\n\n\n// Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction\ntemplate<\n    typename T,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1\n>\n__forceinline__ __device__ void save_rPb_to_sP(\n    Tensor<Engine0, Layout0> &rPb,\n    Tensor<Engine1, Layout1> &sP,\n    int idx_in_warpgroup\n) {\n    auto r2s_copy = make_tiled_copy_C(\n        Copy_Atom<SM90_U32x4_STSM_N, typename T::InputT>{},\n        (typename T::TiledMMA_QK_sQ){}\n    );\n    ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);\n    Tensor thr_copy_rPb = thr_copy.retile_S(rPb);\n    Tensor thr_copy_sP = thr_copy.partition_D(sP);\n    cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);\n}\n\n\n// Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction\ntemplate<\n    typename T,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1\n>\n__forceinline__ __device__ void retrieve_rP_from_sP(\n    Tensor<Engine0, Layout0> &rPb,\n    Tensor<Engine1, Layout1> const &sP,\n    int idx_in_warpgroup\n) {\n    TiledCopy s2r_copy = make_tiled_copy_A(\n        Copy_Atom<SM75_U32x4_LDSM_N, typename T::InputT>{},\n        (typename T::TiledMMA_PV_LocalP){}\n    );\n    ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup);\n    Tensor thr_copy_sP = thr_copy.partition_S(sP);\n    Tensor thr_copy_rPb = thr_copy.retile_D(rPb);\n    cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb);\n}\n\n\n// Rescale rP0 and save the result to rPb\ntemplate<\n    typename T,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2\n>\n__forceinline__ __device__ void wg0_scale_rP0(\n    Tensor<Engine0, Layout0> const &sScale1,\t// (BLOCK_M)\n    
Tensor<Engine1, Layout1> const &rP0,\t\t// ((2, 2, 8), 1, 1)\n    Tensor<Engine2, Layout2> &rPb,\t\t// ((2, 2, 8), 1, 1)\n    int idx_in_warpgroup\n) {\n    CUTLASS_PRAGMA_UNROLL\n    for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {\n        int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);\n        float scale_factor = sScale1(row_idx);\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {\n            rPb(i) = (typename T::InputT)(rP0(i)*scale_factor);\n            rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor);\n        }\n    }\n}\n\n\n// Rescale rO0 according to sScale1\ntemplate<\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1\n>\n__forceinline__ __device__ void wg0_rescale_rO0(\n    Tensor<Engine0, Layout0> &rO0,\n    Tensor<Engine1, Layout1> &sScale1,\n    float rL[2],\n    int idx_in_warpgroup\n) {\n    CUTLASS_PRAGMA_UNROLL\n    for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {\n        int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);\n        float scale_factor = sScale1(row_idx);\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = local_row_idx ? 
2 : 0; i < size(rO0); i += 4) {\n            rO0(i) *= scale_factor;\n            rO0(i+1) *= scale_factor;\n        }\n        rL[local_row_idx] *= scale_factor;\n    }\n}\n\n\n// Fill out-of-bound V with 0.0\n// We must fill it since it may contain NaN, which may propagate to the final result\ntemplate<\n    typename T,\n    typename Engine0, typename Layout0\n>\n__forceinline__ __device__ void fill_oob_V(\n    Tensor<Engine0, Layout0> &sV,\t// tile_to_shape(GMMA::Layout_MN_SW128_Atom<InputT>{}, Shape<Int<HALF_HEAD_DIM>, Int<T::PAGE_BLOCK_SIZE>>{}, LayoutRight{} );\n    int valid_window_size,\n    int idx_in_warpgroup\n) {\n    Tensor sV_int64 = make_tensor(\n        make_smem_ptr((int64_t*)(sV.data().get().get())),\n        tile_to_shape(\n            GMMA::Layout_MN_SW128_Atom<cute::int64_t>{},\n            Shape<Int<256/(64/16)>, Int<T::PAGE_BLOCK_SIZE>>{},\n            LayoutRight{}\n        )\n    );\n    valid_window_size = max(valid_window_size, 0);\n    int head_dim_size = size<0>(sV_int64);\t// 128%head_dim_size == 0 should hold\n    for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) {\n        sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0;\n    }\n}\n\n\n// Store O / OAccum\ntemplate<\n    typename T,\n    bool IS_NO_SPLIT,\n    typename TMAParams,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1\n>\n__forceinline__ __device__ void store_o(\n    Tensor<Engine0, Layout0> &rO,\t// ((2, 2, 32), 1, 1)\n    Tensor<Engine1, Layout1> &gOorAccum,\t// (BLOCK_SIZE_M, HEAD_DIM_V)\n    float rL[2],\n    char* sO_addr,\n    TMAParams &tma_params,\n    int batch_idx,\n    int k_head_idx,\n    int m_block_idx,\n    int num_valid_seq_q,\n    int warpgroup_idx,\n    int idx_in_warpgroup\n) {\n    using InputT = typename T::InputT;\n    if constexpr (IS_NO_SPLIT) {\n        // Should convert the output to bfloat16 / float16, and save it to O\n     
   Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape(\n            GMMA::Layout_K_SW128_Atom<InputT>{},\n            Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{}\n        ));\n\n        Tensor rOb = make_tensor_like<InputT>(rO);\n        CUTLASS_PRAGMA_UNROLL\n        for (int idx = 0; idx < size(rO); ++idx) {\n            rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]);\n        }\n\n        Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));\n        TiledCopy r2s_tiled_copy = make_tiled_copy_C(\n            Copy_Atom<SM90_U32x4_STSM_N, InputT>{},\n            (typename T::TiledMMA_PV_LocalP){}\n        );\n        ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup);\n        Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb);\n        Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf);\n        cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf);\n        cutlass::arch::fence_view_async_shared();\n        \n        __syncthreads();\n\n        if (threadIdx.x == 0) {\n            Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx);\t// (seqlen_q, HEAD_DIM)\n            auto thr_tma = tma_params.tma_O.get_slice(_0{});\n            Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{})(_, _, m_block_idx, _0{});\n            cute::copy(\n                tma_params.tma_O,\n                thr_tma.partition_S(sOutputBuf),\n                thr_tma.partition_D(my_tma_gO)\n            );\n            cute::tma_store_arrive();\n        }\n    } else {\n        // Should save the result to OAccum\n        Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout<\n            Shape<_64, _512>,\n            Stride<Int<520>, _1>\t// We use stride = 520 here to avoid bank conflict\n        >{});\n    \n        CUTLASS_PRAGMA_UNROLL\n        for (int 
idx = 0; idx < size(rO); idx += 2) {\n            int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);\n            int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;\n            *(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 {\n                rO(idx) / rL[idx%4 >= 2],\n                rO(idx+1) / rL[idx%4 >= 2],\n            };\n        }\n        cutlass::arch::fence_view_async_shared();\n        \n        __syncthreads();\n        \n        int row = threadIdx.x;\n        if (row < num_valid_seq_q) {\n            SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float));\n            cute::tma_store_arrive();\n        }\n    }\n}\n\ntemplate<\n    typename T,\n    typename TmaParams, typename Tensor0\n>\n__forceinline__ __device__ void launch_q_copy(\n    TmaParams const &tma_params,\n    int batch_idx,\n    int m_block_idx,\n    int k_head_idx,\n    Tensor0 &sQ,\n    TMABarrier* barrier_Q\n) {\n    if (threadIdx.x == 0) {\n        Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx);\t// (seqlen_q, HEAD_DIM)\n        auto thr_tma = tma_params.tma_Q.get_slice(_0{});\n        Tensor my_tma_gQ = flat_divide(tma_gQ, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_K>>{})(_, _, m_block_idx, _0{});\n        cute::copy(\n            tma_params.tma_Q.with(reinterpret_cast<typename TMABarrier::ValueType &>(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST),\n            thr_tma.partition_S(my_tma_gQ),\n            thr_tma.partition_D(sQ)\n        );\n        barrier_Q->arrive_and_expect_tx(64*576*2);\n    }\n}\n\ntemplate<\n    typename T,\n    bool IS_R,\n    typename Engine0, typename Layout0\n>\n__forceinline__ __device__ auto get_half_V(\n    Tensor<Engine0, Layout0> &sK\n) {\n    Tensor sV = make_tensor(sK.data(), (typename T::SmemLayoutV){});\n    return flat_divide(sV, Shape<Int<T::HEAD_DIM_V/2>, 
Int<T::PAGE_BLOCK_SIZE>>{})(_, _, Int<(int)IS_R>{}, _0{});\n}\n\ntemplate<\n    typename T,\n    bool IS_BLK0_LAST,\t// \"BLK0\" means block_idx+0, \"BLK1\" means block_idx+1, ...\n    bool IS_BLK1_LAST,\n    typename TMAParams,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2,\n    typename Engine3, typename Layout3,\n    typename Engine4, typename Layout4,\n    typename Engine5, typename Layout5,\n    typename Engine6, typename Layout6,\n    typename Engine7, typename Layout7,\n    typename Engine8, typename Layout8,\n    typename Engine9, typename Layout9,\n    typename Engine10, typename Layout10,\n    typename Engine11, typename Layout11\n>\n__forceinline__ __device__ void wg0_subroutine(\n    Tensor<Engine0, Layout0> &tma_gK,\n    Tensor<Engine1, Layout1> &sQ,\n    Tensor<Engine2, Layout2> &sK0,\n    Tensor<Engine3, Layout3> &sK1,\n    Tensor<Engine4, Layout4> &sP0,\n    Tensor<Engine5, Layout5> &sP1,\n    Tensor<Engine6, Layout6> &sM,\n    Tensor<Engine7, Layout7> &sScale0,\n    Tensor<Engine8, Layout8> &sScale1,\n    Tensor<Engine9, Layout9> &rQ8,\n    Tensor<Engine10, Layout10> &rP0,\n    Tensor<Engine11, Layout11> &rO0,\n    float rL[2],\n    int rRightBorderForQSeq[2],\n    TMABarrier barriers_K0[9],\n    TMABarrier barriers_K1[9],\n    bool &cur_phase_K0,\n    const TMAParams &tma_params,\n    const DenseAttnDecodeParams &params,\n    int* block_table_ptr,\n    int seqlen_k,\n    int block_idx,\n    int end_block_idx,\n    int idx_in_warpgroup\n) {\n    int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE;\n    #define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 
0 : __ldg(block_table_ptr + (block_idx)))\n    int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2);\n    int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3);\n\n    Tensor sV0L = get_half_V<T, 0>(sK0);\n    Tensor sV1L = get_half_V<T, 0>(sK1);\n\n    Tensor rPb = make_tensor<T::InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});\n    // Calc P0 = softmax(P0)\n    wg0_bunch_0<T, IS_BLK0_LAST||IS_BLK1_LAST>(rPb, rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup);\n    NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready);\n\n    // Issue rO0 += rPb @ sV0L\n    if constexpr (IS_BLK0_LAST) {\n        fill_oob_V<T>(sV0L, seqlen_k-start_token_idx, idx_in_warpgroup);\n        cutlass::arch::fence_view_async_shared();\n    }\n    warpgroup_cooperative_pv_gemm_localP<T>(rPb, sV0L, rO0, idx_in_warpgroup);\n\n    // Wait for rO0, launch TMA for the next V0L\n    cute::warpgroup_wait<0>();\n    \n    // Wait for warpgroup 1, rescale P0, notify warpgroup 1\n    NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready);\n    if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {\n        // Putting the launch here seems to be faster; the reason is unknown\n        launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup);\n    }\n    wg0_scale_rP0<T>(sScale1, rP0, rPb, idx_in_warpgroup);\n    save_rPb_to_sP<T>(rPb, sP0, idx_in_warpgroup);\n    cutlass::arch::fence_view_async_shared();\n    NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready);\n    \n    // Wait for warpgroup 1, rescale O0, issue rO0 += rPb @ sV1L\n    if constexpr (!IS_BLK0_LAST) {\n        if constexpr (IS_BLK1_LAST) {\n            fill_oob_V<T>(sV1L, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup);\n            cutlass::arch::fence_view_async_shared();\n        }\n        NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued);\n        
wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup);\n        warpgroup_cooperative_pv_gemm_remoteP<T>(sP1, sV1L, rO0, idx_in_warpgroup);\n    }\n    \n    // Issue P0 = Q @ K0^T\n    // Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline.\n    if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {\n        warpgroup_cooperative_qkt_gemm<T, 0>(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup);\n    }\n\n    // Wait for rO0 += rPb @ sV1L, launch TMA\n    if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) {\n        cute::warpgroup_wait<4>();\n        launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup);\n    }\n    \n    // Issue P0 = Q @ K0^T\n    if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {\n        warpgroup_cooperative_qkt_gemm<T, 2>(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup);\n    }\n    \n    // Wait for P0 = Q @ K0^T\n    cute::warpgroup_wait<0>();\n}\n\n\ntemplate<\n    typename T,\n    bool IS_BLK0_LAST,\n    bool IS_BLK1_LAST,\n    bool IS_BLK2_LAST,\n    typename TMAParams,\n    typename Engine0, typename Layout0,\n    typename Engine1, typename Layout1,\n    typename Engine2, typename Layout2,\n    typename Engine3, typename Layout3,\n    typename Engine4, typename Layout4,\n    typename Engine5, typename Layout5,\n    typename Engine6, typename Layout6,\n    typename Engine7, typename Layout7,\n    typename Engine8, typename Layout8,\n    typename Engine9, typename Layout9,\n    typename Engine10, typename Layout10,\n    typename Engine11, typename Layout11\n>\n__forceinline__ __device__ void wg1_subroutine(\n    Tensor<Engine0, Layout0> &tma_gK,\n    Tensor<Engine1, Layout1> &sQ,\n    Tensor<Engine2, Layout2> &sK0,\n    Tensor<Engine3, Layout3> &sK1,\n    Tensor<Engine4, Layout4> &sP0,\n    
Tensor<Engine5, Layout5> &sP1,\n    Tensor<Engine6, Layout6> &sM,\n    Tensor<Engine7, Layout7> &sScale0,\n    Tensor<Engine8, Layout8> &sScale1,\n    Tensor<Engine9, Layout9> &rQ8,\n    Tensor<Engine10, Layout10> &rP1,\n    Tensor<Engine11, Layout11> &rO1,\n    float rL[2],\n    int rRightBorderForQSeq[2],\n    TMABarrier barriers_K0[9],\n    TMABarrier barriers_K1[9],\n    bool &cur_phase_K1,\n    const TMAParams &tma_params,\n    const DenseAttnDecodeParams &params,\n    int* block_table_ptr,\n    int seqlen_k,\n    int block_idx,\n    int end_block_idx,\n    int idx_in_warpgroup\n) {\n    int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE;\n    int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2);\n    int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3);\n\n    Tensor rP1b = make_tensor<T::InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});\n    \n    Tensor sV0R = get_half_V<T, 1>(sK0);\n    Tensor sV1R = get_half_V<T, 1>(sK1);\n\n    // Wait for rP1 and warpgroup 0, run bunch 1, notify warpgroup 0\n    NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready);\n    wg1_bunch_0<T, IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST>(rP1b, sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup);\n    NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready);\n\n    // Save rPb to sP, and issue rO1 += rP1b @ sV1R\n    // We do this after notifying warpgroup 1, since both \"saving rPb to sP\" and \"issuing\" WGMMA are high-latency operations\n    if constexpr (!IS_BLK0_LAST) {\n        save_rPb_to_sP<T>(rP1b, sP1, idx_in_warpgroup);\n    }\n    if constexpr (!IS_BLK0_LAST) {\n        if constexpr (IS_BLK1_LAST) {\n            fill_oob_V<T>(sV1R, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup);\n            cutlass::arch::fence_view_async_shared();\n        }\n        warpgroup_cooperative_pv_gemm_localP<T>(rP1b, sV1R, rO1, idx_in_warpgroup);\n        if 
constexpr (!IS_BLK1_LAST) {\n            // We use this proxy for making sP1 visible to the async proxy\n            // We skip it if IS_BLK1_LAST, since in that case we have already put a fence\n            cutlass::arch::fence_view_async_shared();\n        }\n    }\n    \n    // Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0\n    NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready);\n    if constexpr (IS_BLK0_LAST) {\n        fill_oob_V<T>(sV0R, seqlen_k-start_token_idx, idx_in_warpgroup);\n        cutlass::arch::fence_view_async_shared();\n    }\n    warpgroup_cooperative_pv_gemm_remoteP<T>(sP0, sV0R, rO1, idx_in_warpgroup);\n    if constexpr (!IS_BLK0_LAST) {\n        NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued);\n    }\n    \n    // Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R\n    if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) {\n        cute::warpgroup_wait<1>();\n        launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup);\n    }\n    \n    // Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R\n    if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {\n        cute::warpgroup_wait<0>();\n        launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup);\n    }\n\n    if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) {\n        // Issue rP1 = sQ @ sK1, wait\n        warpgroup_cooperative_qkt_gemm<T, 1>(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup);\n    }\n    \n    // We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise\n    // nvcc cannot correctly analyse the loop, and will think that we are using accumulator\n    // registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized.\n    // This is 
also the reason why we put QK^T here, instead of the first operation in the loop\n    cute::warpgroup_wait<0>();\n}\n\n// A helper function for determining the length of the causal mask for one q token\n__forceinline__ __device__ int get_mask_len(const DenseAttnDecodeParams &params, int m_block_idx, int local_seq_q_idx) {\n    int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx;\n    if (global_seq_q_idx < params.q_seq_per_hk) {\n        int s_q_idx = global_seq_q_idx / params.q_head_per_hk;\n        return params.s_q - s_q_idx - 1;\n    } else {\n        // Out-of-bound request, regard as no masks\n        return 0;\n    }\n}\n\ntemplate<typename T, typename TmaParams>\n__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1)\nflash_fwd_splitkv_mla_kernel(__grid_constant__ const DenseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) {\n    // grid shape: [\n    // \tnum_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))),\n    // \tnum_kv_heads,\n    // \tnum_sm_parts\n    // ]\n    // An \"sm part\" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx).\n    // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. 
Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx])\n    // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file).\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))\n    const int m_block_idx = blockIdx.x;\n    const int k_head_idx = blockIdx.y;\n    const int partition_idx = blockIdx.z;\n    const int warpgroup_idx = threadIdx.x / 128;\n    const int idx_in_warpgroup = threadIdx.x % 128;\n\n    // Define shared tensors\n    extern __shared__ char wksp_buf[];\n    using SharedMemoryPlan = typename T::SharedMemoryPlan;\n    SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);\n    Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ.data()), (typename T::SmemLayoutQ){});\n    Tensor sK0 = make_tensor(make_smem_ptr(plan.smem_sK0.data()), (typename T::SmemLayoutK){});\n    Tensor sK1 = make_tensor(make_smem_ptr(plan.smem_sK1.data()), (typename T::SmemLayoutK){});\n    Tensor sP0 = make_tensor(make_smem_ptr(plan.smem_sP0.data()), (typename T::SmemLayoutP0){});\n    Tensor sP1 = flat_divide(sQ, Shape<Int<T::BLOCK_SIZE_M>, Int<T::PAGE_BLOCK_SIZE>>{})(_, _, _0{}, _8{}); // Overlap with sQ's 8-th tile\n    Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));\n    Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{}));\n    Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));\n    Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));\n    char* sO_addr = (char*)plan.smem_sK0.data();\t// Overlap with sK0 and sK1\n    \n    // Prefetch TMA descriptors\n    if (threadIdx.x == 0) {\n        
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());\n    }\n\n    // Define TMA-related tensors and barriers\n    Tensor tma_gK = tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _);\n    TMABarrier* barriers_K0 = plan.barriers_K0;\n    TMABarrier* barriers_K1 = plan.barriers_K1;\n    TMABarrier* barrier_Q = &(plan.barrier_Q);\n\n    // Initialize TMA barriers\n    if (threadIdx.x == 0) {\n        barrier_Q->init(1);\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < 9; ++i) {\n            barriers_K0[i].init(1);\n            barriers_K1[i].init(1);\n        }\n        cutlass::arch::fence_view_async_shared();\n    }\n    __syncthreads();\n    bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0;\n\n    \n    DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];\n    if (sched_meta.begin_req_idx >= params.b) return;\n\n    // Copy the first Q\n    launch_q_copy<T>(tma_params, sched_meta.begin_req_idx, m_block_idx, k_head_idx, sQ, barrier_Q);\n\n    #pragma unroll 1\n    for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {\n        constexpr int kBlockN = T::PAGE_BLOCK_SIZE;\n        const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;\n        int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);\n        const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;\n        int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);\n        const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? 
!sched_meta.is_last_req_splitted : true);\n        \n        int rRightBorderForQSeq[2];\n        if (params.is_causal) {\n            // The causal mask looks like:\n            // XXXX\n            // XXXX\n            // ...\n            // XXXX\n            //  XXX\n            //  XXX\n            //  ...\n            //  XXX\n            //   XX\n            //   XX\n            //  ...\n            //   XX\n            // Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. We consider the common_mask_len as a \"reduction in the length of the k-sequence.\", and adjust end_block_idx based on it, to save some calculation.\n            // Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks\n            // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling\n            int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1);\n            int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN);\n            end_block_idx = batch_idx == sched_meta.end_req_idx ? 
min(sched_meta.end_block_idx, last_block_in_seq) : last_block_in_seq;\n\n            CUTLASS_PRAGMA_UNROLL\n            for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {\n                int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);\n                rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE);\n            }\n        } else {\n            rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k;\n        }\n\n        // Define global tensors\n        using InputT = typename T::InputT;\n        InputT* o_ptr = (InputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride;\t// (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1)\n        float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M;\t// (BLOCK_SIZE_M) : (1)\n        int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride;\t// (/) : (1)\n        \n        Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(\n            Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{},\n            make_stride(params.o_row_stride, _1{})\n        ));\n        Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout<\n            Shape<Int<T::BLOCK_SIZE_M>>,\n            Stride<_1>\n        >{});\n\n        // Copy K0 and K1\n        launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x);\n        if (start_block_idx+1 < end_block_idx) {\n            launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);\n            launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, 
barriers_K1, threadIdx.x);\n        }\n\n        Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V / 2>>{});\t// ((2, 2, 32), 1, 1)\n        float rL[2];\n        rL[0] = rL[1] = 0.0f;\n        \n        // Clear buffers\n        cute::fill(rO, 0.);\n        if (threadIdx.x < size(sM)) {\n            sM[threadIdx.x] = MAX_INIT_VAL_SM;\n        }\n\n        // Wait for Q\n        barrier_Q->wait(cur_phase_Q);\n        cur_phase_Q ^= 1;\n\n        Tensor rQ8 = make_tensor<InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});\n        retrieve_rP_from_sP<T>(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup);\n\n        if (warpgroup_idx == 0) {\n            // Warpgroup 0\n            Tensor rP0 = make_tensor<float>((typename T::rP0Layout){});\n            \n            // NOTE We don't use the pipelined version of Q K^T here since it leads\n            // to a slow-down (or even register spilling, thanks to the great NVCC)\n            // Wait for K0\n            CUTLASS_PRAGMA_UNROLL\n            for (int i = 0; i < 9; ++i) {\n                if (idx_in_warpgroup == 0)\n                    barriers_K0[i].arrive_and_expect_tx(64*64*2);\n                barriers_K0[i].wait(cur_phase_K0);\n            }\n            cur_phase_K0 ^= 1;\n            \n            // Issue P0 = Q @ K0^T, wait\n            if (start_block_idx-16777216 < end_block_idx) {     // NOTE We use this `if` to prevent register spilling\n                warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);\n            }\n            // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0\n            NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized);\n            cute::warpgroup_wait<0>();\n\n            #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \\\n                wg0_subroutine<T, IS_BLK0_LAST, IS_BLK1_LAST>( \\\n               
     tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \\\n                    rQ8, rP0, rO, rL, rRightBorderForQSeq, \\\n                    barriers_K0, barriers_K1, cur_phase_K0, \\\n                    tma_params, params, \\\n                    block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \\\n                );\n\n            int block_idx = start_block_idx;\n            #pragma unroll 1\n            for (; block_idx < end_block_idx-2; block_idx += 2) {\n                LAUNCH_WG0_SUBROUTINE(false, false);\n            }\n\n            if (block_idx+1 < end_block_idx) {\n                LAUNCH_WG0_SUBROUTINE(false, true);\n            } else if (block_idx < end_block_idx) {\n                LAUNCH_WG0_SUBROUTINE(true, false);\n            }\n\n        } else {\n            // Warpgroup 1\n            Tensor rP1 = make_tensor<float>((typename T::rP0Layout){});\n            \n            if (start_block_idx+1 < end_block_idx) {\n                // Issue rP1 = sQ @ sK1, wait\n                warpgroup_cooperative_qkt_gemm<T, 1>(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup);\n                cute::warpgroup_wait<0>();\n            }\n\n            #define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \\\n                wg1_subroutine<T, IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST>( \\\n                    tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \\\n                    rQ8, rP1, rO, rL, rRightBorderForQSeq, \\\n                    barriers_K0, barriers_K1, cur_phase_K1, \\\n                    tma_params, params, \\\n                    block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \\\n                );\n\n            int block_idx = start_block_idx;\n            #pragma unroll 1\n            for (; block_idx < end_block_idx-3; block_idx += 2) {\n                LAUNCH_WG1_SUBROUTINE(false, false, false);\n            }\n\n            if (block_idx+2 < 
end_block_idx) {\n                LAUNCH_WG1_SUBROUTINE(false, false, true);\n                block_idx += 2;\n                LAUNCH_WG1_SUBROUTINE(true, false, false);\n            } else if (block_idx+1 < end_block_idx) {\n                LAUNCH_WG1_SUBROUTINE(false, true, false);\n            } else if (block_idx < end_block_idx) {\n                LAUNCH_WG1_SUBROUTINE(true, false, false);\n            }\n        }\n\n        // Reduce rL across threads within the same warp\n        rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);\n        rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);\n        rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);\n        rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);\n\n        // Reduce rL across warpgroups\n        int my_row = get_AorC_row_idx(0, idx_in_warpgroup);\n        if (idx_in_warpgroup%4 == 0) {\n            sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0];\n            sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1];\n        }\n        __syncthreads();\n        if (warpgroup_idx == 0) {\n            rL[0] += sL_reduction_wksp[my_row + 64];\n            rL[1] += sL_reduction_wksp[my_row + 8 + 64];\n        } else {\n            if (idx_in_warpgroup%4 == 0) {\n                sL_reduction_wksp[my_row] += rL[0];\n                sL_reduction_wksp[my_row + 8] += rL[1];\n            }\n            __syncwarp();\n            rL[0] = sL_reduction_wksp[my_row];\n            rL[1] = sL_reduction_wksp[my_row+8];\n        }\n\n        // Prune out when rL is 0.0f or NaN\n        // rL may be 0.0f if there are large values (~10^12) in QK^T, which leads\n        // to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error.\n        // When this happens, we set rL to 1.0f. This aligns with the old version\n        // of the MLA kernel.\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < 2; ++i)\n            rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 
1.0f : rL[i];\n\n        // Copy Q for the next batch\n        if (batch_idx+1 <= sched_meta.end_req_idx) {\n            launch_q_copy<T>(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q);\n        } else {\n            // Allow the next kernel (the combine kernel) to launch\n            // The next kernel MUST be the combine kernel\n            cudaTriggerProgrammaticLaunchCompletion();\n        }\n\n        int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M);\n        if (is_no_split) {\n            store_o<T, true>(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);\n\n            int i = threadIdx.x;\n            if (i < num_valid_seq_q) {\n                float cur_L = sL_reduction_wksp[i];\n                gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E;\n            }\n\n            cute::tma_store_wait<0>();\n        } else {\n            // Don't use __ldg because of PDL and instruction reordering\n            int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx;\n            float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V;\t// (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)\n            float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M;\t// (BLOCK_SIZE_M) : (1)\n            Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<\n                Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>,\n                Stride<Int<T::HEAD_DIM_V>, _1>\n            >{});\n            Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout<\n                Shape<Int<T::BLOCK_SIZE_M>>,\n                Stride<_1>\n            >{});\n            
store_o<T, false>(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);\n\n            int i = threadIdx.x;\n            if (i < num_valid_seq_q) {\n                float cur_L = sL_reduction_wksp[i];\n                gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? -INFINITY : log2f(cur_L) + sM(i);\n            }\n\n            cute::tma_store_wait<0>();\n        }\n\n        if (batch_idx != sched_meta.end_req_idx)\n            __syncthreads();\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm90\");\n    }\n#endif\n}\n\n\ntemplate<typename InputT>\nvoid run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {\n    FLASH_ASSERT(params.d == Config::HEAD_DIM_K);\n    FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);\n\n    using T = Traits<InputT>;\n    auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);\n    auto tma_Q = cute::make_tma_copy(\n        SM90_TMA_LOAD{},\n        make_tensor(\n            make_gmem_ptr((InputT*)params.q_ptr),\n            make_layout(\n                shape_Q,\n                make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride)\n            )\n        ),\n        tile_to_shape(\n            GMMA::Layout_K_SW128_Atom<InputT>{},\n            Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_K>>{}\n        )\n    );\n    auto shape_K = make_shape(Int<T::PAGE_BLOCK_SIZE>{}, Int<T::HEAD_DIM_K>{}, params.h_k, params.num_blocks);\n    auto tma_K = cute::make_tma_copy(\n        SM90_TMA_LOAD{},\n        make_tensor(\n            make_gmem_ptr((InputT*)params.k_ptr),\n            make_layout(\n                shape_K,\n                make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride)\n            )\n        ),\n        tile_to_shape(\n            GMMA::Layout_K_SW128_Atom<InputT>{},\n            Layout<\n                
Shape<Int<T::PAGE_BLOCK_SIZE>, Int<64>>,\n                Stride<Int<T::HEAD_DIM_K>, _1>\n            >{}\n        )\n    );\n    auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b);\n    auto tma_O = cute::make_tma_copy(\n        SM90_TMA_STORE{},\n        make_tensor(\n            make_gmem_ptr((InputT*)params.o_ptr),\n            make_layout(\n                shape_O,\n                make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride)\n            )\n        ),\n        tile_to_shape(\n            GMMA::Layout_K_SW128_Atom<InputT>{},\n            Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{}\n        )\n    );\n    TmaParams<decltype(shape_Q), decltype(tma_Q), decltype(shape_K), decltype(tma_K), decltype(shape_O), decltype(tma_O)> tma_params = {\n        shape_Q, tma_Q,\n        shape_K, tma_K,\n        shape_O, tma_O\n    };\n    auto mla_kernel = &flash_fwd_splitkv_mla_kernel<T, decltype(tma_params)>;\n    constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan);\n    CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n\n    // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)\n    const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);\n    cudaLaunchAttribute mla_kernel_attributes[1];\n    mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;\n    mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1;\n    cudaLaunchConfig_t mla_kernel_config = {\n        dim3(num_m_block, params.h_k, params.num_sm_parts),\n        dim3(T::NUM_THREADS, 1, 1),\n        smem_size,\n        params.stream,\n        mla_kernel_attributes,\n        1\n    };\n    cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);\n    CHECK_CUDA_KERNEL_LAUNCH();\n}\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/dense/splitkv_mla.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm90 {\n\ntemplate<typename InputT>\nvoid run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/dense/traits.h",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <cutlass/cutlass.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/barrier.h>\n\n#include \"config.h\"\n\nusing TMABarrier = cutlass::arch::ClusterTransactionBarrier;\nusing namespace cute;\n\ntemplate<typename InputT_>\nstruct Traits {\n    using InputT = InputT_;\n    \n    static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;\n    static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;\n    static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;\n    static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;\n\n    static constexpr int NUM_THREADS = 256;\n\n    static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);\n\n    using TiledMMA_QK_sQ = decltype(make_tiled_mma(\n        GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),\n        Layout<Shape<_1, _1, _1>>{}\n    ));\n\n    using TiledMMA_QK_rQ = decltype(make_tiled_mma(\n        GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),\n        Layout<Shape<_1, _1, _1>>{}\n    ));\n\n    using TiledMMA_PV_LocalP = decltype(make_tiled_mma(\n        GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),\n        Layout<Shape<_1, _1, _1>>{}\n    ));\n\n    using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(\n        GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),\n        Layout<Shape<_1, _1, _1>>{}\n    ));\n\n    using SmemLayoutQ = decltype(tile_to_shape(\n        GMMA::Layout_K_SW128_Atom<InputT>{},\n        Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{}\n    ));\n\n    using SmemLayoutK = decltype(tile_to_shape(\n  
      GMMA::Layout_K_SW128_Atom<InputT>{},\n        Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{}\n    ));\n\n    using SmemLayoutV = decltype(composition(\n        SmemLayoutK{},\n        make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{})\n    ));\t// A transposed version of SmemLayoutK\n\n    using SmemLayoutP0 = decltype(tile_to_shape(\n        GMMA::Layout_K_SW128_Atom<InputT>{},\n        Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}\n    ));\n\n    using rP0Layout = decltype(layout(partition_fragment_C(\n        TiledMMA_QK_sQ{},\n        Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}\n    )));\n\n    struct SharedMemoryPlan {\n        cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ;\n        cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0;\n        cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1;\n        cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;\n        cute::array_aligned<float, BLOCK_SIZE_M> smem_sM;\n        cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp;\n        cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0;\n        cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1;\n        TMABarrier barriers_K0[HEAD_DIM_K/64];\n        TMABarrier barriers_K1[HEAD_DIM_K/64];\n        TMABarrier barrier_Q;\n    };\n\n};\n\ntemplate<\n    typename ShapeQ, typename TMA_Q,\n    typename ShapeK, typename TMA_K,\n    typename ShapeO, typename TMA_O\n>\nstruct TmaParams {\n    ShapeQ shape_Q;\n    TMA_Q tma_Q;\n    ShapeK shape_K;\n    TMA_K tma_K;\n    ShapeO shape_O;\n    TMA_O tma_O;\n};\n\nenum NamedBarriers : int {\n    sScale0Ready = 0,\n    sScale1Ready = 1,\n    sP0Ready = 2,\n    rO1sP0sV0RIssued = 3,\n    sMInitialized = 4,\n};\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/components/config.h",
    "content": "#pragma once\n\n#include <cutlass/numeric_types.h>\n#include <cutlass/arch/barrier.h>\n#include <cute/tensor.hpp>\n#include \"defines.h\"\n\nusing namespace cute;\n\nnamespace sm90::decode::sparse_fp8 {\n\nstatic constexpr int HEAD_DIM_K = 576;\nstatic constexpr int HEAD_DIM_V = 512;\nstatic constexpr int HEAD_DIM_NOPE = HEAD_DIM_V;\nstatic constexpr int HEAD_DIM_ROPE = HEAD_DIM_K - HEAD_DIM_V;\nstatic constexpr int QUANT_TILE_SIZE = 128;\nstatic constexpr int NUM_SCALES = HEAD_DIM_NOPE / QUANT_TILE_SIZE;\nstatic constexpr int NUM_BYTES_PER_TOKEN = HEAD_DIM_NOPE + NUM_SCALES*sizeof(float) + HEAD_DIM_ROPE*sizeof(bf16);\nstatic constexpr int PAGE_BLOCK_SIZE = 64;\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/components/dequant.h",
    "content": "#pragma once\n\n#include <cuda_fp8.h>\n#include <cuda_bf16.h>\n\n#include \"defines.h\"\n\nnamespace sm90::decode::sparse_fp8 {\n\nstruct fp8x8 {\n    __nv_fp8x4_e4m3 lo;\n    __nv_fp8x4_e4m3 hi;\n};\n\nstruct fp8x16 {\n    fp8x8 lo;\n    fp8x8 hi;\n};\n\n__device__ __forceinline__\nbf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) {\n    #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \\\n    { \\\n        float4 fp32x4 = (float4)(FP8x4); \\\n        OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \\\n        OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \\\n    }\n\n    bf16x8 result;\n    DEQUANT_FP8x4(result.a01, result.a23, inputs.lo);\n    DEQUANT_FP8x4(result.a45, result.a67, inputs.hi);\n\n    return result;\n}\n\nenum class L1CacheHint {\n    NO_ALLOCATE,\n    EVICT_FIRST,\n    EVICT_NORMAL,\n    EVICT_LAST\n};\n\nenum class L2PrefetchHint {\n    B64,\n    B128,\n    B256\n};\n\ntemplate<\n    typename T,\n    L1CacheHint l1_cache_hint,\n    L2PrefetchHint l2_prefetch_hint\n>\n__device__ __forceinline__\nT load_128b_from_gmem(const void* addr) {\n    static_assert(sizeof(T) == 128/8);\n    int4 ret;\n\n    #define EXEC(L1_HINT_STR, L2_HINT_STR) { \\\n        asm volatile(\"ld.global.nc.L1::\" L1_HINT_STR \".L2::\" L2_HINT_STR \".v4.s32 {%0, %1, %2, %3}, [%4];\" \\\n            : \"=r\"(ret.x), \"=r\"(ret.y), \"=r\"(ret.z), \"=r\"(ret.w) \\\n            : \"l\"(addr)); \\\n    }\n\n    #define DISPATCH_L2(L1_HINT_STR) { \\\n        if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \\\n            EXEC(L1_HINT_STR, \"64B\") \\\n        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \\\n            EXEC(L1_HINT_STR, \"128B\") \\\n        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \\\n            EXEC(L1_HINT_STR, \"256B\") \\\n    }\n\n    if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)\n        
DISPATCH_L2(\"no_allocate\")\n    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)\n        DISPATCH_L2(\"evict_first\")\n    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)\n        DISPATCH_L2(\"evict_normal\")\n    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)\n        DISPATCH_L2(\"evict_last\")\n\n    #undef EXEC\n    #undef DISPATCH_L2\n    return *reinterpret_cast<T*>(&ret);\n}\n\ntemplate<\n    typename T,\n    L1CacheHint l1_cache_hint,\n    L2PrefetchHint l2_prefetch_hint\n>\n__device__ __forceinline__\nT load_64b_from_gmem(const void* addr) {\n    static_assert(sizeof(T) == 64/8);\n    int2 ret;\n\n    #define EXEC(L1_HINT_STR, L2_HINT_STR) { \\\n        asm volatile(\"ld.global.nc.L1::\" L1_HINT_STR \".L2::\" L2_HINT_STR \".v2.s32 {%0, %1}, [%2];\" \\\n            : \"=r\"(ret.x), \"=r\"(ret.y) \\\n            : \"l\"(addr)); \\\n    }\n\n    #define DISPATCH_L2(L1_HINT_STR) { \\\n        if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \\\n            EXEC(L1_HINT_STR, \"64B\") \\\n        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \\\n            EXEC(L1_HINT_STR, \"128B\") \\\n        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \\\n            EXEC(L1_HINT_STR, \"256B\") \\\n    }\n\n    if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)\n        DISPATCH_L2(\"no_allocate\")\n    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)\n        DISPATCH_L2(\"evict_first\")\n    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)\n        DISPATCH_L2(\"evict_normal\")\n    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)\n        DISPATCH_L2(\"evict_last\")\n\n    #undef EXEC\n    #undef DISPATCH_L2\n    return *reinterpret_cast<T*>(&ret);\n}\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/components/helpers.h",
    "content": "#pragma once\n\n#include <cooperative_groups.h>\n#include <cute/tensor.hpp>\n\n#include \"config.h\"\n\nusing namespace cute;\n\nnamespace sm90::decode::sparse_fp8 {\n\n// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx\n// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a\n__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {\n    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);\n    return row_idx;\n}\n\n\n// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h\ntemplate <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\n__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {\n    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;\n    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const\n    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n    warpgroup_fence_operand(tCrC);\n    if constexpr (arrive) {\n        warpgroup_arrive();\n    }\n    if constexpr (zero_init) {\n        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;\n        // Unroll the K mode manually to set scale D to 1\n        CUTLASS_PRAGMA_UNROLL\n        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n            tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n        }\n    } else {\n        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);\n        // Unroll the K mode manually 
to set scale D to 1\n        CUTLASS_PRAGMA_UNROLL\n        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n            tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n        }\n    }\n    if constexpr (commit) {\n        warpgroup_commit_batch();\n    }\n    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }\n    warpgroup_fence_operand(tCrC);\n    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n}\n\ntemplate<\n    typename TMA,\n    typename Tensor0,\n    typename Tensor1\n>\nCUTE_DEVICE\nvoid launch_tma_copy(\n    const TMA &tma_copy,\n    const Tensor0 &src,\n    Tensor1 &dst,\n    transac_bar_t &bar,\n    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL,\n    const uint16_t &multicast_mask = 0\n) {\n    auto thr_tma = tma_copy.get_slice(_0{});\n    cute::copy(\n        tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), multicast_mask, cache_hint),\n        thr_tma.partition_S(src),\n        thr_tma.partition_D(dst)\n    );\n}\n\ntemplate<typename T>\nCUTE_DEVICE\nstatic void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {\n    long2 data_long2 = *reinterpret_cast<const long2*>(&data);\n    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);\n    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);\n    asm volatile (\n        \"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \\n\"\n        :\n        : \"r\"(dst_addr), \"l\"(data_long2.x), \"l\"(data_long2.y), \"r\"(mbar_addr)\n    );\n}\n\nCUTE_DEVICE\nstatic void cp_async_bulk_shared_cta_shared_cluster(void* dst_ptr, void* src_ptr, int size, transac_bar_t* mbar_ptr) {\n    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);\n    uint32_t src_addr = cute::cast_smem_ptr_to_uint(src_ptr);\n    uint32_t mbar_addr = 
cute::cast_smem_ptr_to_uint(mbar_ptr);\n    asm volatile (\n        \"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \\n\"\n        :\n        : \"r\"(dst_addr), \"r\"(src_addr), \"r\"(size), \"r\"(mbar_addr)\n    );\n}\n\nstatic constexpr int PEER_ADDR_MASK = 16777216; // = 1 << 24. peer_addr = my_addr ^ PEER_ADDR_MASK.\ntemplate<typename T>\nCUTE_DEVICE\nT* get_peer_addr(T* p) {\n    return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);\n}\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/config.h",
    "content": "#pragma once\n\n#include <cutlass/numeric_types.h>\n#include <cutlass/arch/barrier.h>\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\n#include \"defines.h\"\n#include \"params.h\"\n\nusing namespace cute;\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate<ModelType MODEL_TYPE, int NUM_HEADS>\nclass KernelTemplate {\npublic:\n\nstatic_assert(NUM_HEADS == 64 || NUM_HEADS == 128);\nstatic constexpr int NUM_M_BLOCKS = NUM_HEADS / 64;\nstatic constexpr int CLUSTER_SIZE = NUM_M_BLOCKS;\n\nstatic constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512;\nstatic constexpr int HEAD_DIM_V = 512;\nstatic constexpr int HEAD_DIM_ROPE = 64;\nstatic constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE;\n\nstatic constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;\nstatic constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8;  // For MODEL1: 7 fp8_e4m3 + 1 padding\n\nstatic constexpr int NUM_THREADS = 128*3;\nstatic constexpr int BLOCK_M = 64;\nstatic constexpr int TOPK_BLOCK_SIZE = 64;\nstatic constexpr int NUM_K_BUFS = 2;\n\nusing SmemLayoutQTile = decltype(tile_to_shape(\n    GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},\n    Shape<Int<BLOCK_M>, Int<64>>{}\n));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutQTiles = decltype(tile_to_shape(\n    SmemLayoutQTile{},\n    Shape<Int<BLOCK_M>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n));\n\nusing SmemLayoutQ = SmemLayoutQTiles<HEAD_DIM_K/64>;\n\nusing SmemLayoutKTile = decltype(tile_to_shape(\n    GMMA::Layout_INTER_Atom<bf16, GMMA::Major::K>{},\n    Shape<Int<TOPK_BLOCK_SIZE>, _64>{},\n    Step<_1, _2>{}\n));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles = decltype(tile_to_shape(\n    SmemLayoutKTile{},\n    Shape<Int<TOPK_BLOCK_SIZE>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTilesTransposed = decltype(composition(\n\tSmemLayoutKTiles<NUM_TILES>{},\n\tLayout<Shape<Int<64*NUM_TILES>, 
Int<TOPK_BLOCK_SIZE>>, Stride<Int<TOPK_BLOCK_SIZE>, _1>>{}\n));\n\nstatic constexpr int OBUF_SW = 64;\nusing SmemLayoutOBufAtom = GMMA::Layout_K_SW128_Atom<bf16>;\nusing SmemLayoutOBuf = decltype(tile_to_shape(\n    SmemLayoutOBufAtom{},\n    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},\n    Step<_1, _2>{}\n));\n\nusing SmemLayoutOAccumBuf = Layout<\n    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,\n    Stride<Int<520>, _1>\t// We use stride = 520 here to avoid bank conflict\n>;\n\nusing SmemLayoutK = SmemLayoutKTiles<HEAD_DIM_K/64>;\nusing SmemLayoutV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64>;\nusing SmemLayoutHalfV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64/2>;\n\nusing SmemLayoutS = decltype(tile_to_shape(\n    GMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}\n));\n\nstruct SharedMemoryPlan {\n    array_aligned<bf16, cosize_v<SmemLayoutQ>> q;\n    union {\n        array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];\n        array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;\n        array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;\n    } u;\n    CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;\n    bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];\n\n    float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];\n    transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];\n};\n\ntemplate<\n    typename Shape_Q, typename TMA_Q\n>\nstruct TmaParams {\n    Shape_Q shape_Q; TMA_Q tma_Q;\n    CUtensorMap tensor_map_o;\n};\n\nusing TiledMMA_QK = decltype(make_tiled_mma(\n    GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},\n    Layout<Shape<_1, _1, _1>>{}\n));\n\nusing TiledMMA_QK_rQ = decltype(make_tiled_mma(\n    GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},\n    Layout<Shape<_1, _1, _1>>{}\n));\n\nusing TiledMMA_PV_LocalP = decltype(make_tiled_mma(\n    GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, 
GMMA::Major::MN>{},\n    Layout<Shape<_1, _1, _1>>{}\n));\n\nusing TiledMMA_PV_RemoteP = decltype(make_tiled_mma(\n    GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},\n    Layout<Shape<_1, _1, _1>>{}\n));\n\n\nenum NamedBarriers : uint32_t {\n    sScale_and_sS_ready = 0,\n    sScale_and_sS_free = 1,\n    oBuf_free_and_sL_ready = 2,\n    epilogue_r2s_ready = 3,\n    batch_loop_sync = 4,\n    warpgroup0_sync = 5\n};\n\n\n// Synchronize all threads within the cluster (which processes one q token)\nstatic __forceinline__ __device__ void sync_all_threads_in_cluster() {\n    if constexpr (CLUSTER_SIZE == 1) {\n        __syncthreads();\n    } else {\n        ku::barrier_cluster_arrive_relaxed();\n        ku::barrier_cluster_wait_acquire();\n    }\n}\n\n// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction\ntemplate<\n    typename Tensor0,\n    typename Tensor1\n>\nstatic __forceinline__ __device__ void save_rPb_to_sP(\n    Tensor0 const &rPb,\n    Tensor1 const &sP,\n    int idx_in_warpgroup\n) {\n    auto r2s_copy = make_tiled_copy_C(\n        Copy_Atom<SM90_U32x4_STSM_N, bf16>{},\n        TiledMMA_QK{}\n    );\n    ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);\n    Tensor thr_copy_rPb = thr_copy.retile_S(rPb);\n    Tensor thr_copy_sP = thr_copy.partition_D(sP);\n    cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);\n}\n\n\ntemplate<\n    bool IS_NO_SPLIT,\n    typename TMAParams,\n    typename Tensor0,\n    typename Tensor1,\n    typename Tensor2,\n    typename Tensor3\n>\nstatic __forceinline__ __device__ void store_o(\n    Tensor0 &rO,\t// ((2, 2, 32), 1, 1)\n    Tensor1 &gOorAccum,\t// (BLOCK_SIZE_M, HEAD_DIM_V)\n    Tensor2 &sOutputBuf,\n    Tensor3 &sOutputAccumBuf,\n    SharedMemoryPlan &plan,\n    float o_scales[2],\n    TMAParams &tma_params,\n    int batch_idx,\n    int s_q_idx,\n    int head_block_idx,\n    int num_valid_seq_q,\n    int warpgroup_idx,\n    int idx_in_warpgroup\n) {\n    using 
cutlass::arch::NamedBarrier;\n    if constexpr (IS_NO_SPLIT) {\n        // Should convert the output to bfloat16 / float16, and save it to O\n        // Here we don't pipeline STSM and tma store because it's slower\n        Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));\n\n        // Calculate \"base\" ptrs in advance\n        // Each STSM fills a chunk of shape 16x16, while we are using SW-OBUF_SW, so we need OBUF_SW/16 base pointers\n        constexpr int NUM_CHUNKS_IN_SW_ATOM = OBUF_SW/16;\n        bf16* base_output_buf_ptrs[NUM_CHUNKS_IN_SW_ATOM];\n        CUTE_UNROLL\n        for (int i = 0; i < NUM_CHUNKS_IN_SW_ATOM; ++i) {\n            base_output_buf_ptrs[i] = &sMyOutputBuf((idx_in_warpgroup/32)*16+idx_in_warpgroup%16, idx_in_warpgroup%32/16*8 + i*16);\n        }\n\n        CUTE_UNROLL\n        for (int idx = 0; idx < (HEAD_DIM_V/2)/16; idx += 1) {\n            // In each iteration we deal with a chunk of shape 16x16\n            using bf16x2 = __nv_bfloat162;\n            bf16x2 a01 = __float22bfloat162_rn(float2{rO(idx*8+0)*o_scales[0], rO(idx*8+1)*o_scales[0]});\n            bf16x2 a23 = __float22bfloat162_rn(float2{rO(idx*8+2)*o_scales[1], rO(idx*8+3)*o_scales[1]});\n            bf16x2 a45 = __float22bfloat162_rn(float2{rO(idx*8+4)*o_scales[0], rO(idx*8+5)*o_scales[0]});\n            bf16x2 a67 = __float22bfloat162_rn(float2{rO(idx*8+6)*o_scales[1], rO(idx*8+7)*o_scales[1]});\n            SM90_U32x4_STSM_N::copy(\n                *reinterpret_cast<uint32_t*>(&a01),\n                *reinterpret_cast<uint32_t*>(&a23),\n                *reinterpret_cast<uint32_t*>(&a45),\n                *reinterpret_cast<uint32_t*>(&a67),\n                *reinterpret_cast<uint128_t*>(base_output_buf_ptrs[idx%4] + (idx/4*4)*16*64)\n            );\n        }\n\n        cutlass::arch::fence_view_async_shared();\n        NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);\n\n        if (threadIdx.x == 0) 
{\n            SM90_TMA_STORE_5D::copy(\n                &tma_params.tensor_map_o,\n                plan.u.oBuf.data(),\n                0, head_block_idx*64, 0,\n                s_q_idx, batch_idx\n            );\n            cute::tma_store_arrive();\n        }\n    } else {\n        // Should save the result to OAccum\n        CUTLASS_PRAGMA_UNROLL\n        for (int idx = 0; idx < size(rO); idx += 2) {\n            int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);\n            int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;\n            *(float2*)(&(sOutputAccumBuf(row, col))) = float2 {\n                rO(idx) * o_scales[idx%4>=2],\n                rO(idx+1) * o_scales[idx%4>=2],\n            };\n        }\n        cutlass::arch::fence_view_async_shared();\n        \n        NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);\n        \n        if (elect_one_sync()) {\n            CUTLASS_PRAGMA_UNROLL\n            for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) {\n                int row = local_row * (256/32) + (threadIdx.x / 32);\n                if (row < num_valid_seq_q) {\n                    SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float));\n                }\n            }\n            cute::tma_store_arrive();\n        }\n    }\n}\n\n\ntemplate<typename TMAParams>\nstatic __device__ __forceinline__ void\ndevfunc(const SparseAttnDecodeParams &params, const TMAParams &tma_params);\n\nstatic void run(const SparseAttnDecodeParams &params);\n\n};\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
    "content": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 128>(const SparseAttnDecodeParams &params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
    "content": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 64>(const SparseAttnDecodeParams &params);\n\n}\n\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
    "content": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 128>(const SparseAttnDecodeParams &params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
    "content": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 64>(const SparseAttnDecodeParams &params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh",
    "content": "#pragma once\n\n#include \"splitkv_mla.h\"\n\n#include <cuda_fp8.h>\n#include <math_constants.h>\n#include <cutlass/barrier.h>\n#include <cutlass/arch/barrier.h>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cutlass/cluster_launch.hpp>\n\n#include <kerutils/kerutils.cuh>\n\n#include \"utils.h\"\n#include \"components/dequant.h\"\n#include \"components/helpers.h\"\n#include \"config.h\"\nusing namespace cute;\n\nnamespace sm90::decode::sparse_fp8 {\n\nstatic constexpr float MAX_INIT_VAL = -1e30;    // Prevent (-inf) - (-inf) = nan\nusing cutlass::arch::fence_view_async_shared;\nusing cutlass::arch::NamedBarrier;\nusing fp8_e8m0 = __nv_fp8_e8m0;\n\ntemplate<\n    typename Tensor0,\n    typename Tensor1,\n    typename Tensor2\n>\n__forceinline__ __device__ void scale_softmax(\n    Tensor0 &rP,\n    Tensor1 &rS,\n    Tensor2 &rO,\n    float scale_softmax_log2,\n    float sScale[],\n    float rM[2],\n    float rL[2],\n    bool is_kv_valid[],\n    int block_idx,\n    int idx_in_warpgroup\n) {\n    float scale_for_olds[2];\n    CUTE_UNROLL\n    for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {\n        Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, _), _, _));\n        Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _));\n        Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _));\n\n        float cur_max = -INFINITY;\n        CUTE_UNROLL\n        for (int i = 0; i < size(cur_rP); ++i) {\n            if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2])\n                cur_rP(i) = -INFINITY;\n            cur_max = max(cur_max, cur_rP(i));\n        }\n        cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));\n        cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));\n\n        cur_max *= scale_softmax_log2;\n        float old_max = rM[local_row_idx];\n        rM[local_row_idx] = max(cur_max, old_max);\n        float scale_for_old = exp2f(old_max - 
rM[local_row_idx]);\n        scale_for_olds[local_row_idx] = scale_for_old;\n\n        CUTE_UNROLL\n        for (int i = 0; i < size(cur_rO); ++i) {\n            cur_rO(i) *= scale_for_old;\n        }\n\n        float cur_sum = 0;\n        CUTE_UNROLL\n        for (int i = 0; i < size(cur_rP); ++i) {\n            cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]);\n            cur_rS(i) = (bf16)cur_rP(i);\n            cur_sum += cur_rP(i);\n        }\n\n        rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;\n    }\n    if (idx_in_warpgroup%4 == 0)\n        *(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds);\n}\n\ntemplate<ModelType MODEL_TYPE, int NUM_HEADS>\ntemplate<typename TMAParams>\n__device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::devfunc(const SparseAttnDecodeParams &params, const TMAParams &tma_params) {\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))\n    const int head_block_idx = NUM_M_BLOCKS == 1 ? 0 : blockIdx.x;\n    const int s_q_idx = blockIdx.y;\n    const int partition_idx = blockIdx.z;\n    const int idx_in_cluster = CLUSTER_SIZE == 1 ? 
0 : head_block_idx % 2;\n    const int warpgroup_idx = cutlass::canonical_warp_group_idx();\n    const int idx_in_warpgroup = threadIdx.x % 128;\n    const int warp_idx = cutlass::canonical_warp_idx_sync();\n\n    // Define shared tensors\n    extern __shared__ char wksp_buf[];\n    SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);\n    Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{});\n    Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{});\n    Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{});\n    Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{});\n    float* sM = plan.sM;\n    float* sL = plan.sL;\n    float* sScale = plan.sScale;\n    \n    // Prefetch TMA descriptors\n    if (warp_idx == 0 && elect_one_sync()) {\n        cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(&tma_params.tensor_map_o);\n    }\n    \n    // Initialize TMA barriers\n    if (warp_idx == 0 && elect_one_sync()) {\n        plan.bar_q.init(1);\n        if constexpr (CLUSTER_SIZE == 2) {\n            CUTE_UNROLL\n            for (int i = 0; i < NUM_K_BUFS; ++i) {\n                plan.bar_k_local_ready[i].init(128);\n                plan.bar_k_remote_ready[i].init(1);\n                plan.bar_k_avail[i].init(4);\n            }\n        } else {\n            CUTE_UNROLL\n            for (int i = 0; i < NUM_K_BUFS; ++i) {\n                plan.bar_k_local_ready[i].init(128);\n                plan.bar_k_avail[i].init(256);\n            }\n        }\n        cutlass::arch::fence_barrier_init();\n    }\n    ku::barrier_cluster_arrive_relaxed();\n\n    int bar_phase_k = 0; // Don't use array here to prevent using local memory\n\n    // Programmatic Dependent Launch: Wait for the previous kernel to finish\n    // Don't use PDL because of compiler bugs!\n    // cudaGridDependencySynchronize();\n    \n    
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];\n\n    if (sched_meta.begin_req_idx >= params.b) return;\n\n    if (warp_idx == 0 && elect_one_sync()) {\n        Tensor gQ = flat_divide(\n            tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, sched_meta.begin_req_idx),\n            Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}\n        )(_, _, head_block_idx, _0{});\n        launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);\n        plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));\n    }\n\n    ku::barrier_cluster_wait_acquire();\n\n    struct MainloopArgs {\n        int start_block_idx, end_block_idx;\n        bool is_no_split;\n\n        // The following fields are only valid for MODEL1\n        int topk_length, extra_topk_length, num_orig_kv_blocks;\n    };\n    auto get_cur_req_info = [&](int batch_idx) -> MainloopArgs {\n        MainloopArgs args;\n        int total_topk_padded;\n        if constexpr (MODEL_TYPE == ModelType::V32) {\n            total_topk_padded = params.topk;\n        } else {\n            int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;\n            int orig_topk_padded = max(ku::ceil(topk_length, (int)TOPK_BLOCK_SIZE), (int)TOPK_BLOCK_SIZE);\n            int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;\n            total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)TOPK_BLOCK_SIZE);\n            args.topk_length = topk_length;\n            args.extra_topk_length = extra_topk_length;\n            args.num_orig_kv_blocks = orig_topk_padded / TOPK_BLOCK_SIZE;\n        }\n\n        args.start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;\n        args.end_block_idx = batch_idx == sched_meta.end_req_idx ? 
sched_meta.end_block_idx : total_topk_padded / TOPK_BLOCK_SIZE;\n        args.is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);\n\n        return args;\n    };\n\n    if (warpgroup_idx == 0) {\n        cutlass::arch::warpgroup_reg_alloc<192>();\n\n        TiledMMA tiled_mma_QK = TiledMMA_QK{};\n        ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup);\n        TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{};\n        ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup);\n        \n        float rL[2], rM[2];\n        Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V/2>>{});\n        Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});\n        Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}));\n\n        float rAttn_sink[2] = {-CUDART_INF_F, -CUDART_INF_F};\n        if (params.attn_sink != nullptr) {\n            for (int i = 0; i < 2; ++i) {\n                int head_idx = head_block_idx*BLOCK_M + get_AorC_row_idx(i, idx_in_warpgroup);\n                rAttn_sink[i] = __ldg((float*)params.attn_sink + head_idx) * CUDART_L2E_F;\n            }\n        }\n\n        #pragma unroll 1\n        for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {\n            MainloopArgs args = get_cur_req_info(batch_idx);\n\n            rL[0] = rL[1] = 0.0f;\n            rM[0] = rM[1] = MAX_INIT_VAL;\n            cute::fill(rO, 0.);\n\n            // Wait for Q\n            plan.bar_q.wait((sched_meta.begin_req_idx-batch_idx)&1);\n\n            CUTE_NO_UNROLL\n            for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {\n                int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;\n                Tensor 
sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});\n                Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{});\n\n                // Wait, issue WGMMA\n                plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);\n                if constexpr (CLUSTER_SIZE == 2) {\n                    plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);\n                }\n\n                gemm<true, -1>(\n                    tiled_mma_QK,\n                    thr_mma_QK.partition_fragment_A(sQ),\n                    thr_mma_QK.partition_fragment_B(sK),\n                    rP\n                );\n\n                bar_phase_k ^= 1<<buf_idx;\n\n                cute::warpgroup_wait<0>();\n                \n                // Calculate S = softmax(mask(scale(P)))\n                if (block_idx != args.start_block_idx)\n                    NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free);  // Make sure that sScale and sS are free\n\n                // Since TOPK_BLOCK_SIZE == BLOCK_M in our case, we only need to do OOB checking for the last 2 blocks\n                scale_softmax(rP, rS, rO, params.sm_scale_div_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup);\n\n                // Store S into shared, inform warpgroup 1\n                save_rPb_to_sP(rS, sS, idx_in_warpgroup);\n                fence_view_async_shared();\n\n                // Issue O += S @ V\n                gemm<false, -1>(\n                    tiled_mma_PV,\n                    rS,\n                    thr_mma_PV.partition_fragment_B(sV),\n                    rO\n                );\n\n                NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready);\n\n                cute::warpgroup_wait<0>();\n\n                if constexpr (CLUSTER_SIZE == 2) {\n                    plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);\n                    
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);\n                } else {\n                    plan.bar_k_avail[buf_idx].arrive();\n                }\n            }\n\n            // Copy the next q\n            if (threadIdx.x/32 == 0 && elect_one_sync()) {\n                if (batch_idx != sched_meta.end_req_idx) {\n                    Tensor gQ = flat_divide(\n                        tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1),\n                        Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}\n                    )(_, _, head_block_idx, _0{});\n                    launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);\n                    plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));\n                } else {\n                    // This kernel is followed by the combine kernel, so we signal PDL here\n                    cudaTriggerProgrammaticLaunchCompletion();\n                }\n            }\n\n            // Synchronize L and M across warpgroups\n            rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);\n            rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);\n            rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);\n            rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);\n\n            if (idx_in_warpgroup%4 == 0) {\n                CUTE_UNROLL\n                for (int i = 0; i < 2; ++i) {\n                    int row = get_AorC_row_idx(i, idx_in_warpgroup);\n                    sL[row] = rL[i];\n                    sM[row] = rM[i];\n                }\n            }\n            \n            float o_scales[2];\n            CUTE_UNROLL\n            for (int i = 0; i < 2; ++i) {\n                if (args.is_no_split) {\n                    o_scales[i] = rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i] + exp2f(rAttn_sink[i] - rM[i]));\n                } else {\n                    o_scales[i] = rL[i] == 0.0f ? 
0.0f : __fdividef(1.0f, rL[i]);\n                }\n                if (idx_in_warpgroup%4 == 0) {\n                    int row = get_AorC_row_idx(i, idx_in_warpgroup);\n                    plan.sOScale[row] = o_scales[i];\n                }\n            }\n\n            // This is a synchronization point for warpgroup 0/1.\n            // Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free\n            // Warpgroup 1 should wait wg 0 for sL to be ready\n            NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready);\n\n            CUTE_UNROLL\n            for (int i = 0; i < 2; ++i)\n                rL[i] = rL[i] == 0.0f ? 1.0f : rL[i];\n            \n            int start_head_idx = head_block_idx*BLOCK_M;\n            int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M);\n            if (args.is_no_split) {\n                bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q;\t// (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1)\n                Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(\n                    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},\n                    make_stride(params.stride_o_h_q, _1{})\n                ));\n                float* gSoftmaxLse = (float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + start_head_idx;\t// (BLOCK_M) : (1)\n\n                store_o<true>(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);\n\n                int i = threadIdx.x;\n                if (i < num_valid_seq_q) {\n                    float cur_L = sL[i];\n                    gSoftmaxLse[i] = cur_L == 0.0f ? 
INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E;\n                }\n\n                cute::tma_store_wait<0>();\n            } else {\n                int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;\n                int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;\n                float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q;\t// (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1)\n                float* gSoftmaxLseAccum = (float*)params.lse_accum + split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + start_head_idx;\t// (BLOCK_M) : (1)\n                Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout(\n                    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},\n                    make_stride(params.stride_o_accum_h_q, _1{})\n                ));\n                store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);\n\n                int i = threadIdx.x;\n                if (i < num_valid_seq_q) {\n                    float cur_L = sL[i];\n                    gSoftmaxLseAccum[i] = cur_L == 0.0f ? 
-INFINITY : log2f(cur_L) + sM[i];\n                }\n\n                cute::tma_store_wait<0>();\n            }\n            \n            sync_all_threads_in_cluster();\n        }\n    } else if (warpgroup_idx == 1) {\n        cutlass::arch::warpgroup_reg_dealloc<160>();\n\n        TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{};\n        ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup);\n        Tensor rO = partition_fragment_C(tiled_mma_PV, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V/2>>{});\n\n        #pragma unroll 1\n        for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {\n            MainloopArgs args = get_cur_req_info(batch_idx);\n            cute::fill(rO, 0.);\n\n            CUTE_NO_UNROLL\n            for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {\n                int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;\n                Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{});\n\n                // Wait for S and sScale\n                NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready);\n\n                // Scale O\n                float cur_scales[2];\n                *(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2);\n                CUTE_UNROLL\n                for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {\n                    Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _));\n                    CUTE_UNROLL\n                    for (int i = 0; i < size(cur_rO); ++i) {\n                        cur_rO(i) *= cur_scales[local_row_idx];\n                    }\n                }\n                \n                // Issue O += S @ V, and wait\n                gemm<false, -1>(\n                    tiled_mma_PV,\n                    thr_mma_PV.partition_fragment_A(sS),\n                    
thr_mma_PV.partition_fragment_B(sV),\n                    rO\n                );\n                cute::warpgroup_wait<0>();\n                \n                if constexpr (CLUSTER_SIZE == 2) {\n                    plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);\n                    plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);\n                } else {\n                    plan.bar_k_avail[buf_idx].arrive();\n                }\n                \n                if (block_idx != args.end_block_idx-1)\n                    NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free);   // Tell WG0 that sScale and sS are available\n            }\n\n            NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready);\n\n            float o_scales[2];\n            CUTE_UNROLL\n            for (int i = 0; i < 2; ++i) {\n                int row = get_AorC_row_idx(i, idx_in_warpgroup);\n                o_scales[i] = plan.sOScale[row];\n            }\n                \n            int start_head_idx = head_block_idx*BLOCK_M;\n            int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M);\n            if (args.is_no_split) {\n                bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q;\t// (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1)\n                Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(\n                    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},\n                    make_stride(params.stride_o_h_q, _1{})\n                ));\n\n                store_o<true>(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);\n\n                cute::tma_store_wait<0>();\n            } else {\n                int n_split_idx = batch_idx == sched_meta.begin_req_idx ? 
sched_meta.begin_split_idx : 0;\n                int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;\n                float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q;\t// (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1)\n                Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout(\n                    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},\n                    make_stride(params.stride_o_accum_h_q, _1{})\n                ));\n                store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);\n\n                cute::tma_store_wait<0>();\n            }\n\n            sync_all_threads_in_cluster();\n        }\n    } else {\n        // Producer warpgroup\n        cutlass::arch::warpgroup_reg_dealloc<152>();\n\n        static_assert(CLUSTER_SIZE == 1 || CLUSTER_SIZE == 2);\n        static constexpr int NUM_TOKENS_PER_THREAD = CLUSTER_SIZE == 1 ? 
2 : 1;\n        static constexpr int NUM_TOKENS_PER_ROUND = 32; // If head is 128, each CTA is responsible for dequantizing 32 tokens (1 round); if head is 64, each CTA is responsible for dequantizing 64 tokens (2 rounds)\n        int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0);\n        int lane_idx = idx_in_warpgroup % 32;\n        int my_token_idx_base = warp_idx*8 + lane_idx%8;\n        \n        CUTE_NO_UNROLL\n        for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {\n            MainloopArgs args = get_cur_req_info(batch_idx);\n            int* gIndices = params.indices + batch_idx*params.stride_indices_b + s_q_idx*params.stride_indices_s_q; // (topk) : (1)\n            int* gExtraIndices = params.extra_indices + batch_idx*params.stride_extra_indices_b + s_q_idx*params.stride_extra_indices_s_q; // (extra_topk) : (1)\n            \n            int nxt_token_indexs[NUM_TOKENS_PER_THREAD];\n            CUTE_UNROLL\n            for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) {\n                if (MODEL_TYPE == ModelType::V32 || args.start_block_idx < args.num_orig_kv_blocks)\n                    nxt_token_indexs[round] = __ldg(gIndices + args.start_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + round*NUM_TOKENS_PER_ROUND + my_token_idx_base);\n            }\n\n            struct IsOrigBlock {};\n            struct IsExtraBlock {};\n\n            struct IsFirstExtraBlock {};\n            struct IsNotFirstExtraBlock {};\n            auto process_one_block = [&](int block_idx, auto is_extra_block_t, auto is_first_extra_block_t) {\n                static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;\n                static constexpr bool IS_FIRST_EXTRA_BLOCK = std::is_same_v<decltype(is_first_extra_block_t), IsFirstExtraBlock>;\n                int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;\n\n                
int* indices_base;\n                int page_block_size;\n                int64_t k_block_stride, k_row_stride;\n                fp8* k_ptr;\n                if constexpr (!IS_EXTRA_BLOCK) {\n                    indices_base = gIndices + (block_idx)*TOPK_BLOCK_SIZE;\n                    page_block_size = params.page_block_size;\n                    k_block_stride = params.stride_kv_block;\n                    k_row_stride = params.stride_kv_row;\n                    k_ptr = (fp8*)params.kv;\n                } else {\n                    indices_base = gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE;\n                    page_block_size = params.extra_page_block_size;\n                    k_block_stride = params.stride_extra_kv_block;\n                    k_row_stride = params.stride_extra_kv_row;\n                    k_ptr = (fp8*)params.extra_kv;\n                }\n                [[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length;\n                [[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx;\n                transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx]));\n\n                CUTE_UNROLL\n                for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) {\n                    int my_token_idx = my_token_idx_base + round*NUM_TOKENS_PER_ROUND;\n                    bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE;\n                    bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base);\n\n                    // Get prefetched token index\n                    int token_index;\n                    if constexpr (!IS_EXTRA_BLOCK) {\n                        token_index = nxt_token_indexs[round];\n                        if (block_idx+1 != (MODEL_TYPE == ModelType::V32 ? 
args.end_block_idx : args.num_orig_kv_blocks))\n                            nxt_token_indexs[round] = __ldg(gIndices + (block_idx+1)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);\n                    } else {\n                        if constexpr (IS_FIRST_EXTRA_BLOCK) {\n                            token_index = __ldg(gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);\n                        } else {\n                            token_index = nxt_token_indexs[round];\n                        }\n                        if (block_idx+1 != args.end_block_idx)\n                            nxt_token_indexs[round] = __ldg(gExtraIndices + (block_idx+1-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);\n                    }\n                    \n                    if constexpr (MODEL_TYPE == ModelType::MODEL1) {\n                        // For MODEL1, we need to check whether the token_index is within topk_length\n                        if (rel_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx >= topk_length) {\n                            token_index = -1;   // To prevent IMA when we have invalid (e.g. INT_MAX) topk indexes outside topk_length\n                        }\n                    }\n\n                    int block_index = token_index == -1 ? 
0 : (int)((uint32_t)token_index/(uint32_t)page_block_size);   // Use uint32_t division and mod to improve performance\n                    int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size;   // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error\n\n                    fp8* gK_base;\n                    bf16 scales[NUM_SCALES];\n                    if constexpr (MODEL_TYPE == ModelType::V32) {\n                        static_assert(NUM_SCALES == 4);\n                        gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*k_row_stride;\n                        float scales_float[NUM_SCALES];\n                        *(float4*)(scales_float) = load_128b_from_gmem<float4, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>((float*)(gK_base+HEAD_DIM_NOPE));\n                        CUTE_UNROLL\n                        for (int i = 0; i < NUM_SCALES; ++i) {\n                            scales[i] = (bf16)scales_float[i];\n                        }\n                    } else {\n                        static_assert(NUM_SCALES == 8);\n                        gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*sizeof(bf16));\n                        fp8_e8m0* gK_scales_base = (fp8_e8m0*)(k_ptr + block_index*k_block_stride + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*sizeof(bf16)) + rel_idx_in_block*NUM_SCALES*sizeof(fp8_e8m0));\n                        fp8_e8m0 scales_e8m0[NUM_SCALES];\n                        *(int64_t*)scales_e8m0 = __ldg((int64_t*)gK_scales_base);\n                        CUTE_UNROLL\n                        for (int i = 0; i < NUM_SCALES; i += 2) {\n                            *(__nv_bfloat162_raw*)(scales+i) = __nv_cvt_e8m0x2_to_bf162raw(*(__nv_fp8x2_storage_t*)(scales_e8m0+i));\n                        }\n                    }\n\n                    // Wait for the nope buffer to be available\n 
                   if (round == 0) {\n                        plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1);\n                    }\n                    \n                    if (CLUSTER_SIZE == 2 && round == 0 && idx_in_warpgroup == 0) {\n                        plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16));\n                    }\n\n                    // Collectively copy from global memory and dequant\n                    // For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py\n                    \n                    fp8* gK_nope = gK_base + (lane_idx/8)*16;\n                    if (token_index == -1) {\n                        CUTE_UNROLL\n                        for (int i = 0; i < NUM_SCALES; ++i)\n                            scales[i] = (bf16)0.0f;\n                    }\n                    CUTE_UNROLL\n                    for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) {\n                        fp8x16 cur_fp8x16 = load_128b_from_gmem<fp8x16, L1CacheHint::EVICT_LAST, L2PrefetchHint::B256>(gK_nope + dim_idx*64);   // We use EVICT_LAST here since gK_base may not be aligned to 32B (for V3.2) and the performance is the best among all cache hints (for MODEL1)\n                        bf16 scale = scales[MODEL_TYPE == ModelType::V32 ? 
dim_idx/2 : dim_idx];\n                        auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) {\n                            int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE;\n                            bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, __bfloat162bfloat162(*(__nv_bfloat16*)(&scale)));\n                            *(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;\n                            if constexpr (CLUSTER_SIZE == 2) {\n                                st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);\n                            }\n                        };\n                        if (token_index == -1)\n                            *(uint128_t*)(&cur_fp8x16) = uint128_t();\n                        dequant_and_save_bf16x8(cur_fp8x16.lo, 0);\n                        dequant_and_save_bf16x8(cur_fp8x16.hi, 8);\n                    }\n\n                    bf16* gK_rope;\n                    if constexpr (MODEL_TYPE == ModelType::V32) {\n                        gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8;\n                    } else {\n                        gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE) + (lane_idx/8)*8;\n                    }\n                    bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE;\n                    bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base);\n\n                    CUTE_UNROLL\n                    for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) {\n                        bf16x8 cur_bf16x8 = load_128b_from_gmem<bf16x8, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>(gK_rope + dim_idx*32);\n                        if constexpr (MODEL_TYPE == ModelType::V32) {\n                            // NOTE We do not need to mask the RoPE part for V3.2 since it isn't involved in the SV gemm\n          
              } else {\n                            if (token_index == -1)\n                                *(uint128_t*)(&cur_bf16x8) = uint128_t();\n                        }\n                        int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE;\n                        *(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;\n                        if constexpr (CLUSTER_SIZE == 2) {\n                            st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);\n                        }\n                    }\n                }\n\n                fence_view_async_shared();\n\n                if (idx_in_warpgroup < 32) {\n                    // We put this after fence_view_async_shared() since this won't be read by async proxy\n                    auto is_index_valid = [&](int index, int offset_within_thread) -> bool {\n                        if constexpr (MODEL_TYPE == ModelType::V32) {\n                            return index != -1;\n                        } else {\n                            return index != -1 && rel_block_idx*TOPK_BLOCK_SIZE + lane_idx*2 + offset_within_thread < topk_length;\n                        }\n                    };\n                    int2 indices = __ldg((int2*)(indices_base + lane_idx*2));\n                    *(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {\n                        is_index_valid(indices.x, 0),\n                        is_index_valid(indices.y, 1)\n                    };\n                }\n\n                // Signal the barrier\n                plan.bar_k_local_ready[buf_idx].arrive();\n                bar_phase_k ^= 1 << buf_idx;\n            };\n\n            if constexpr (MODEL_TYPE == ModelType::V32) {\n                CUTE_NO_UNROLL\n                for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {\n                    process_one_block(block_idx, IsOrigBlock{}, 
IsNotFirstExtraBlock{});\n                }\n            } else {\n                CUTE_NO_UNROLL\n                for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {\n                    process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{});\n                }\n\n                if (args.num_orig_kv_blocks < args.end_block_idx) {\n                    process_one_block(max(args.start_block_idx, args.num_orig_kv_blocks), IsExtraBlock{}, IsFirstExtraBlock{});\n                }\n                CUTE_NO_UNROLL\n                for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks)+1; block_idx < args.end_block_idx; ++block_idx) {\n                    process_one_block(block_idx, IsExtraBlock{}, IsNotFirstExtraBlock{});\n                }\n            }\n\n            sync_all_threads_in_cluster();\n        }\n    }\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm90\");\n    }\n#endif\n\n}\n\ntemplate<typename Kernel, typename TMAParams>\n__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, Kernel::CLUSTER_SIZE)\nflash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TMAParams tma_params) {\n    Kernel::devfunc(params, tma_params);\n}\n\ntemplate<ModelType MODEL_TYPE, int NUM_HEADS>\nvoid KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &params) {\n    KU_ASSERT(params.h_kv == 1);\n    KU_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0);\n    KU_ASSERT(params.d_qk == HEAD_DIM_K);\n    KU_ASSERT(params.d_v == HEAD_DIM_V);\n    KU_ASSERT(params.h_q % BLOCK_M == 0);\n    if constexpr (MODEL_TYPE == ModelType::MODEL1) {\n        constexpr int BYTES_PER_TOKEN = HEAD_DIM_NOPE + 2*HEAD_DIM_ROPE + 8;\n        KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, \"Each page block in KV cache must be contiguous for head64 sparse fp8 decoding 
attention in MODEL1\");  // Each block must be contiguous\n        if (params.extra_kv != nullptr) {\n            KU_ASSERT(params.stride_extra_kv_row == BYTES_PER_TOKEN, \"Each page block in extra KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1\");  // Each block must be contiguous\n        }\n    } else {\n        KU_ASSERT(params.extra_kv == nullptr, \"V3.2 does not support extra KV cache\");\n        KU_ASSERT(params.topk_length == nullptr, \"V3.2 does not support dynamic topk length\");\n        KU_ASSERT(params.stride_kv_row == 656);  // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16)\n    }\n\n    auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q, params.b);\n    auto tma_Q = cute::make_tma_copy(\n        SM90_TMA_LOAD{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.q),\n            make_layout(\n                shape_Q,\n                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b)\n            )\n        ),\n        SmemLayoutQ{}\n    );\n    \n    CUtensorMap tensor_map_o;\n    {\n        // Here we manually construct TMA descriptor to store O, in order to leverage 5D TMA\n        uint64_t size[5] = {OBUF_SW, (unsigned long)params.h_q, HEAD_DIM_V/OBUF_SW, (unsigned long)params.s_q, (unsigned long)params.b};\n        uint64_t stride[4] = {params.stride_o_h_q*sizeof(bf16), OBUF_SW*sizeof(bf16), params.stride_o_s_q*sizeof(bf16), params.stride_o_b*sizeof(bf16)};\n        uint32_t box_size[5] = {OBUF_SW, BLOCK_M, HEAD_DIM_V/OBUF_SW, 1, 1};\n        uint32_t elem_stride[5] = {1, 1, 1, 1, 1};\n        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(\n            &tensor_map_o,\n            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            5,\n            params.out,\n            size,\n            stride,\n            box_size,\n            elem_stride,\n            
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,\n            OBUF_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B :\n                OBUF_SW == 32 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B :\n                OBUF_SW == 16 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B :\n                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,\n            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,\n            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE\n        );\n        KU_ASSERT(res == CUresult::CUDA_SUCCESS);\n    }\n\n    TmaParams<\n        decltype(shape_Q), decltype(tma_Q)\n    > tma_params = {\n        shape_Q, tma_Q,\n        tensor_map_o\n    };\n    auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE, NUM_HEADS>, decltype(tma_params)>;\n\n    constexpr size_t smem_size = sizeof(SharedMemoryPlan);\n    KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n\n    // NOTE Don't use PDL because of potential compiler bugs!\n    // cudaLaunchAttribute mla_kernel_attributes[1];\n    // mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;\n    // mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1;\n    // cudaLaunchConfig_t mla_kernel_config = {\n    //     dim3(num_m_block, params.h_k, params.num_sm_parts),\n    //     dim3(NUM_THREADS, 1, 1),\n    //     smem_size,\n    //     stream,\n    //     mla_kernel_attributes,\n    //     1\n    // };\n    // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);\n    cutlass::ClusterLaunchParams launch_params = {\n        dim3(NUM_M_BLOCKS, params.s_q, params.num_sm_parts),\n        dim3(NUM_THREADS, 1, 1),\n        dim3(CLUSTER_SIZE, 1, 1),\n        smem_size,\n        params.stream\n    };\n    cutlass::launch_kernel_on_cluster(\n        launch_params, (void*)mla_kernel, params, tma_params\n    );\n    
KU_CHECK_KERNEL_LAUNCH();\n}\n\ntemplate<ModelType MODEL_TYPE, int NUM_HEADS>\nvoid run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params) {\n    KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(params);\n}\n\n}\n"
  },
  {
    "path": "csrc/sm90/decode/sparse_fp8/splitkv_mla.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate<ModelType MODEL_TYPE, int NUM_HEADS>\nvoid run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);\n\n}\n\n"
  },
  {
    "path": "csrc/sm90/helpers.h",
    "content": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <cutlass/arch/barrier.h>\n\nnamespace sm90 {\n\n__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {\n    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);\n    asm volatile(\"cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\\n\"\n        :: \"r\"(dst_addr),\n           \"l\"(src),\n           \"n\"(16));\n}\n\n__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) {\n    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);\n    asm volatile(\"cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\\n\"\n        :: \"r\"(dst_addr),\n           \"l\"(src),\n           \"r\"(pred?16:0),\n           \"l\"(cache_policy));\n}\n\n__forceinline__ __device__ int64_t createpolicy_evict_last() {\n    int64_t res;\n    asm volatile(\n        \"createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \\n\\t\"\n        : \"=l\"(res)\n        :\n    );\n    return res;\n}\n\n__forceinline__ __device__ int64_t createpolicy_evict_first() {\n    int64_t res;\n    asm volatile(\n        \"createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \\n\\t\"\n        : \"=l\"(res)\n        :\n    );\n    return res;\n}\n\n\n__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {\n    // In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. 
This function converts the local_row_idx (0 or 1) to the actual row_idx\n    // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a\n    int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);\n    return row_idx;\n}\n\n__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {\n    int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);\n    return col_idx;\n}\n\n// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h\n// * Copyright (c) 2024, Tri Dao.\ntemplate <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\n__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {\n    using namespace cute;\n    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;\n    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const\n    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n    warpgroup_fence_operand(tCrC);\n    if constexpr (arrive) {\n        warpgroup_arrive();\n    }\n    if constexpr (zero_init) {\n        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;\n        // Unroll the K mode manually to set scale D to 1\n        CUTLASS_PRAGMA_UNROLL\n        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n            tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n        }\n    } else {\n        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);\n        // Unroll the K mode manually to set scale D to 1\n        CUTLASS_PRAGMA_UNROLL\n        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n   
         cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n            tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n        }\n    }\n    if constexpr (commit) {\n        warpgroup_commit_batch();\n    }\n    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }\n    warpgroup_fence_operand(tCrC);\n    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }\n}\n\n// A simpler version of gemm\ntemplate <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\n__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {\n    using namespace cute;\n    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);\n    Tensor sA_frag = thr_mma.partition_fragment_A(sA);\n    Tensor sB_frag = thr_mma.partition_fragment_B(sB);\n    static_assert(size<2>(sA_frag) == size<2>(sB_frag));\n\n    warpgroup_fence_operand(rC_frag);\n    warpgroup_arrive();\n    tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k = 0; k < size<2>(sA_frag); ++k) {\n        cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);\n        tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n    }\n    warpgroup_fence_operand(rC_frag);\n}\n\ntemplate <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\n__forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {\n    using namespace cute;\n    ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);\n    Tensor sB_frag = thr_mma.partition_fragment_B(sB);\n    static_assert(size<2>(rA_frag) == size<2>(sB_frag));\n\n    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));\n    warpgroup_fence_operand(rC_frag);\n    warpgroup_arrive();\n    tiled_mma.accumulate_ = clear_accum ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;\n    CUTLASS_PRAGMA_UNROLL\n    for (int k = 0; k < size<2>(rA_frag); ++k) {\n        cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);\n        tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n    }\n    warpgroup_fence_operand(rC_frag);\n    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));\n}\n\n\n__forceinline__ __device__ uint32_t get_sm_id() {\n    uint32_t ret;\n    asm(\"mov.u32 %0, %%smid;\" : \"=r\"(ret));\n    return ret;\n}\n\nstatic constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. Not sure if this number is the same on all GPUs.\ntemplate<typename T>\nCUTE_DEVICE\nT* get_peer_addr(const T* p) {\n    return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);\n}\n\ntemplate<\n    typename TMA,\n    typename Tensor0,\n    typename Tensor1\n>\nCUTE_DEVICE\nvoid launch_tma_copy(\n    const TMA &tma_copy,\n    Tensor0 src,\n    Tensor1 dst,\n    cutlass::arch::ClusterTransactionBarrier &bar,\n    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL\n) {\n    auto thr_tma = tma_copy.get_slice(cute::_0{});\n    cute::copy(\n        tma_copy.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(bar), 0, cache_hint),\n        thr_tma.partition_S(src),\n        thr_tma.partition_D(dst)\n    );\n}\n\n}\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/config.h",
    "content": "#pragma once\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cluster_launch.hpp>\n#include <cooperative_groups.h>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cutlass/arch/arch.h>\n#include <kerutils/kerutils.cuh>\n\n#include \"defines.h\"\n#include \"params.h\"\n\nnamespace sm90::fwd {\n\nusing namespace cute;\n\ntemplate<int D_QK, bool HAVE_TOPK_LENGTH>\nclass KernelTemplate {\npublic:\n\nstatic constexpr int D_Q = D_QK;\nstatic constexpr int D_K = D_QK;\nstatic constexpr int D_V = 512;\n\nstatic constexpr int B_H = 64;\nstatic constexpr int B_TOPK = 64;    // TopK block size\nstatic constexpr int NUM_THREADS = 128*3;\nstatic constexpr float MAX_INIT_VAL = -1e30;    // We use this number as the initial value for mi (max logits)\n\ntemplate<int NUM_TILES>\nusing SmemLayoutQTiles = decltype(coalesce(tile_to_shape(\n    GMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutOTiles = decltype(coalesce(tile_to_shape(\n    GMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTiles = decltype(coalesce(tile_to_shape(\n    GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},\n    Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{},\n    Step<_1, _2>{}\n), Shape<_1, _1>{}));\n\ntemplate<int NUM_TILES>\nusing SmemLayoutKTilesTransposed = decltype(composition(\n\tSmemLayoutKTiles<NUM_TILES>{},\n\tLayout<Shape<Int<64*NUM_TILES>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{}\n));\n\nusing SmemLayoutQ = SmemLayoutQTiles<D_Q/64>;\nusing SmemLayoutO = SmemLayoutOTiles<D_V/64>;\nusing SmemLayoutK = SmemLayoutKTiles<D_Q/64>;\nusing SmemLayoutV = SmemLayoutKTilesTransposed<D_V/64>;\nusing SmemLayoutHalfV = SmemLayoutKTilesTransposed<D_V/64/2>;\n\nusing SmemLayoutS = decltype(coalesce(tile_to_shape(\n    
GMMA::Layout_K_SW128_Atom<bf16>{},\n    Shape<Int<B_H>, Int<B_TOPK>>{}\n), Shape<_1, _1>{}));\n\nstruct SharedMemoryPlan {\n    union {\n        array_aligned<bf16, cosize_v<SmemLayoutQ>> q;\n        array_aligned<bf16, cosize_v<SmemLayoutO>> o;\n    } q_o;\n    array_aligned<bf16, cosize_v<SmemLayoutK>> k[2];\n    array_aligned<bf16, cosize_v<SmemLayoutS>> s[D_QK == 576 ? 1 : 2];  // For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers\n\n    bool is_kv_valid[2][B_TOPK];\n    float2 sM[32];\n    float2 sL[64];   // For reduction across WG0/1 in epilogue\n    float final_max_logits[64], final_lse[64];\n    transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;\n};\n\nusing TiledMMA_QK = decltype(make_tiled_mma(\n    GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},\n    Layout<Shape<_1, _1, _1>>{}\n));\n\nusing TiledMMA_PV_LocalP = decltype(make_tiled_mma(\n    GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},\n    Layout<Shape<_1, _1, _1>>{}\n));\n\nusing TiledMMA_PV_RemoteP = decltype(make_tiled_mma(\n    GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},\n    Layout<Shape<_1, _1, _1>>{}\n));\n\ntemplate<\n    typename Shape_Q, typename TMA_Q\n>\nstruct TmaParams {\n    Shape_Q shape_Q; TMA_Q tma_Q;\n    CUtensorMap tensor_map_O;\n};\n\nenum NamedBarriers : uint32_t {\n    wg0_bunch_0_ready = 0,\n    wg1_bunch_0_ready = 1,\n    wg0_s0_ready = 2,\n    wg1_s1_ready = 3,\n    sL_ready = 4,\n    warpgroup0_sync = 5,\n    warpgroup1_sync = 6,\n    epilogue_sync = 7\n};\n\n// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction\ntemplate<\n    typename Tensor0,\n    typename Tensor1\n>\nstatic __forceinline__ __device__ void save_rS_to_sS(\n    Tensor0 const &rPb,\n    Tensor1 const &sP,\n    int idx_in_warpgroup\n) {\n    auto r2s_copy = 
make_tiled_copy_C(\n        Copy_Atom<SM90_U32x4_STSM_N, bf16>{},\n        TiledMMA_QK{}\n    );\n    ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);\n    Tensor thr_copy_rPb = thr_copy.retile_S(rPb);\n    Tensor thr_copy_sP = thr_copy.partition_D(sP);\n    cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);\n}\n\ntemplate<typename TMAParams>\nstatic __device__ __forceinline__ void\ndevfunc(const SparseAttnFwdParams &params, const TMAParams &tma_params);\n\nstatic void run(const SparseAttnFwdParams &params);\n\n};\n\n\n};\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/fwd.cu",
    "content": "#include \"fwd.h\"\n\n#include <stdexcept>\n\n#include \"phase1.h\"\n\nnamespace sm90 {\n\nvoid run_fwd_kernel(const SparseAttnFwdParams& params) {\n    const bool have_topk_length = params.topk_length != nullptr;\n\n    // Dispatch based on d_qk dimension and presence of topk_length\n    if (params.d_qk == 512) {\n        if (have_topk_length) {\n            sm90::fwd::run_fwd_phase1_kernel<512, true>(params);\n        } else {\n            sm90::fwd::run_fwd_phase1_kernel<512, false>(params);\n        }\n    } else if (params.d_qk == 576) {\n        if (have_topk_length) {\n            sm90::fwd::run_fwd_phase1_kernel<576, true>(params);\n        } else {\n            sm90::fwd::run_fwd_phase1_kernel<576, false>(params);\n        }\n    } else {\n        throw std::runtime_error(\"Unsupported d_qk value in sparse attention fwd kernel\");\n    }\n}\n\n}  // namespace sm90\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/fwd.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm90 {\n\nvoid run_fwd_kernel(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\n// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH\n// = true / false respectively, to compile them in parallel.\ntemplate void run_fwd_phase1_kernel<512, false>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\n// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH\n// = true / false respectively, to compile them in parallel.\ntemplate void run_fwd_phase1_kernel<512, true>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\ntemplate void run_fwd_phase1_kernel<576, false>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
    "content": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\ntemplate void run_fwd_phase1_kernel<576, true>(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/phase1.cuh",
    "content": "#pragma once\n\n#include \"config.h\"\n\n#include \"utils.h\"\n#include \"../../helpers.h\"\n\nnamespace sm90::fwd {\n\nusing namespace cute;\n\nCUTE_DEVICE void st_global_cs_128(float f0, float f1, float f2, float f3, void *dst_ptr) {\n    asm volatile(\"st.weak.global.cs.v4.f32 [%0], {%1, %2, %3, %4};\\n\"\n                 :\n                 : \"l\"(dst_ptr),\n                   \"f\"(f0), \"f\"(f1), \"f\"(f2), \"f\"(f3)\n                );\n}\n\nCUTE_DEVICE\nfloat2 __shfl_xor_sync_float2(\n    uint32_t mask, float2 value, int offset\n) {\n    float2 res;\n    *reinterpret_cast<long long*>(&res) = __shfl_xor_sync(\n        mask,\n        *reinterpret_cast<long long*>(&value),\n        offset\n    );\n    return res;\n}\n\nCUTE_DEVICE\nvoid tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(src_ptr);\n    asm volatile(\"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\\n\"\n                     :\n                     : \"l\"(dst_ptr), \"r\"(smem_int_ptr), \"r\"(store_bytes)\n                     : \"memory\");\n}\n\ntemplate<int D_QK, bool HAVE_TOPK_LENGTH>\ntemplate<typename TMAParams>\n__device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params, const TMAParams &tma_params) {\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))\n    const int q_h_idx = blockIdx.x % (params.h_q/B_H);\n    const int s_q_idx = blockIdx.x / (params.h_q/B_H);\n    const int warpgroup_idx = cutlass::canonical_warp_group_idx();\n    const int warp_idx = cutlass::canonical_warp_idx_sync();\n    const int idx_in_warpgroup = threadIdx.x % 128;\n\n    // Define shared tensors\n    extern __shared__ char wksp_buf[];\n    SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);\n    Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{});\n   
 Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{});\n    Tensor sS0 = make_tensor(make_smem_ptr(D_QK == 576 ? plan.k[0].data()+64*512 : plan.s[1].data()), SmemLayoutS{});    // Overlap with sK0's RoPE part for V3.2\n    Tensor sS1 = make_tensor(make_smem_ptr(plan.s[0].data()), SmemLayoutS{});\n\n    if (warp_idx == 0 && elect_one_sync()) {\n        // Prefetch TMA descriptors\n        cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(&tma_params.tensor_map_O);\n\n        // Initialize barriers\n        plan.bar_q.init(1);\n        CUTE_UNROLL\n        for (int i = 0; i < 2; ++i) {\n            plan.bar_k0_free[i].init(128);\n            plan.bar_k0_ready[i].init(128);\n            plan.bar_k1_free[i].init(128);\n            plan.bar_k1_ready[i].init(128);\n        }\n        plan.bar_is_kv_valid_ready.init(16);\n        fence_barrier_init();\n    }\n\n    __syncthreads();\n    \n    const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk;\n    const int num_topk_blocks = HAVE_TOPK_LENGTH ? 
ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK);\n\n    if (warpgroup_idx == 0 || warpgroup_idx == 1) {\n        cutlass::arch::warpgroup_reg_alloc<216>();\n\n        if (warp_idx == 0 && elect_one_sync()) {\n            // Load Q\n            Tensor gQ = flat_divide(\n                tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),\n                Tile<Int<B_H>, Int<D_Q>>{}\n            )(_, _, q_h_idx, _0{});\n            launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);\n            plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));\n        }\n\n        float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation\n        float rL[2] = {0.0f, 0.0f};\n        Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<D_V/2>>{});\n        Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<B_H>, Int<B_TOPK>>{});\n        Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<B_TOPK>>{}));\n        cute::fill(rO, 0.0f);\n        \n        // Wait for Q\n        plan.bar_q.wait(0);\n\n        bool cur_bar_wait_phase = 0;\n        \n        struct Warpgroup0 {};\n        struct Warpgroup1 {};\n\n        auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) {\n            constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;\n            TiledMMA tiled_mma_QK = TiledMMA_QK{};\n            Tensor sQ_tile = flat_divide(sQ, Tile<Int<B_H>, Int<64>>{})(_, _, _0{}, tile_idx);\n            Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{});\n            gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup);\n        };\n\n        auto mask_rP = [&](auto warpgroup_idx) {\n            constexpr bool IS_WG1 = 
std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;\n            plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);\n            CUTE_UNROLL\n            for (int row_idx = 0; row_idx < 2; ++row_idx) {\n                CUTE_UNROLL\n                for (int i = row_idx*2; i < size(rP); i += 4) {\n                    int col = 8*(i/4) + (idx_in_warpgroup%4)*2;\n                    if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY;\n                    if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY;\n                }\n            }\n        };\n\n        auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) {\n            plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);\n            constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;\n            const float scale = params.sm_scale_div_log2;\n            float r_sM[2];\n            if constexpr (IS_WG1) {\n                *(float2*)r_sM = plan.sM[idx_in_warpgroup/4];\n            }\n            float new_maxs[2];\n            CUTE_UNROLL\n            for (int row_idx = 0; row_idx < 2; ++row_idx) {\n                // Get rowwise max\n                float cur_max = -INFINITY;\n                CUTE_UNROLL\n                for (int i = row_idx*2; i < size(rP); i += 4) {\n                    cur_max = max(cur_max, max(rP(i), rP(i+1)));\n                }\n                cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));\n                cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));\n                cur_max *= scale;\n\n                // Get new max and scale\n                // For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round)\n                new_maxs[row_idx] = max(IS_WG1 ? 
r_sM[row_idx] : rM[row_idx], cur_max);\n\n                // Scale O\n                float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]);\n                CUTE_UNROLL\n                for (int i = row_idx*2; i < size(rO); i += 4) {\n                    rO(i) *= scale_for_o;\n                    rO(i+1) *= scale_for_o;\n                }\n\n                // Get rS\n                float cur_sum = 0;\n                CUTE_UNROLL\n                for (int i = row_idx*2; i < size(rP); i += 4) {\n                    rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]);\n                    rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]);\n                    rS(i) = (bf16)rP(i);\n                    rS(i+1) = (bf16)rP(i+1);\n                    cur_sum += rP(i) + rP(i+1);\n                }\n                rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum;\n            }\n            __syncwarp();\n            if (idx_in_warpgroup%4 == 0) {\n                plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs;\n            }\n            rM[0] = new_maxs[0];\n            rM[1] = new_maxs[1];\n        };\n\n        auto reduce_L = [&]() {\n            // Reduce L\n            // For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131\n            rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);\n            rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);\n            rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);\n            rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);\n            if (idx_in_warpgroup%4 == 0)\n                plan.sL[threadIdx.x/4] = *(float2*)(rL);\n            NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready);\n            float2 peer_L = plan.sL[(threadIdx.x/4)^32];\n            rL[0] += peer_L.x;\n            rL[1] += peer_L.y;\n        };\n\n        auto store_O = [&]() {\n            float scale_factors[2];\n            CUTE_UNROLL\n            for (int i = 0; i < 2; ++i) {\n                
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : params.attn_sink[q_h_idx*B_H + get_AorC_row_idx(i, idx_in_warpgroup)]*CUDART_L2E_F;\n                scale_factors[i] = 1.0f / (rL[i] + exp2f(attn_sink - rM[i]));\n                if (rL[i] == 0.0f)\n                    scale_factors[i] = 0.0f;    // The output should be 0 whatever attn_sink is\n            }\n\n            Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{});\n            bf16* stsm_addrs[4];\n            int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16);\n            CUTE_UNROLL\n            for (int i = 0; i < 64/16; ++i) {\n                stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i);\n            }\n            bool s2g_pred = warp_idx%4 == 0 && elect_one_sync();\n\n            warpgroup_wait<0>();\n            CUTE_UNROLL\n            for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) {\n                // Convert\n                constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128;   // 64: tile size, 128: warpgroup size\n                bf16 cur_rOb[NUM_ELEMS_EACH_TILE];\n                CUTE_UNROLL\n                for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) {\n                    cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]);\n                }\n                // R -> S\n                CUTE_UNROLL\n                for (int i = 0; i < 64/16; ++i) {\n                    SM90_U32x4_STSM_N::copy(\n                        *reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 0),\n                        *reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 2),\n                        *reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 4),\n                        *reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 6),\n                        *reinterpret_cast<uint128_t*>(stsm_addrs[i] + tile_idx*(B_H*64))\n                    );\n                }\n                
fence_view_async_shared();\n                NamedBarrier::arrive_and_wait(128, warpgroup_idx ? NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync);\n                // S -> G\n                if (s2g_pred) {\n                    int g_tile_idx = warpgroup_idx*4 + tile_idx;\n                    SM90_TMA_STORE_3D::copy(\n                        &tma_params.tensor_map_O,\n                        plan.q_o.o.data() + g_tile_idx*(B_H*64),\n                        g_tile_idx*64,\n                        q_h_idx*B_H,\n                        s_q_idx\n                    );\n                }\n            }\n            cute::tma_store_arrive();\n        };\n\n\n        if (warpgroup_idx == 0) {\n            // Warpgroup 0\n\n            auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) {\n                plan.bar_k0_ready[0].wait(cur_bar_wait_phase);\n                qkt_gemm_one_tile(Warpgroup0{}, 0, true);\n                qkt_gemm_one_tile(Warpgroup0{}, 1, false);\n                qkt_gemm_one_tile(Warpgroup0{}, 2, false);\n                qkt_gemm_one_tile(Warpgroup0{}, 3, false);\n                warpgroup_commit_batch();\n            };\n\n            auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) {\n                plan.bar_k0_ready[1].wait(cur_bar_wait_phase);\n                qkt_gemm_one_tile(Warpgroup0{}, 4, false);\n                qkt_gemm_one_tile(Warpgroup0{}, 5, false);\n                qkt_gemm_one_tile(Warpgroup0{}, 6, false);\n                qkt_gemm_one_tile(Warpgroup0{}, 7, false);\n                if constexpr (D_QK == 576) {\n                    qkt_gemm_one_tile(Warpgroup0{}, 8, false);\n                }\n                warpgroup_commit_batch();\n            };\n\n            auto scale_rS = [&](float scales[2]) {\n                CUTE_UNROLL\n                for (int row = 0; row < 2; ++row) {\n                    CUTE_UNROLL\n                    for (int i = row*2; i < size(rP); i 
+= 4) {\n                        rS(i) = (bf16)(rP(i) * scales[row]);\n                        rS(i+1) = (bf16)(rP(i+1) * scales[row]);\n                    }\n                }\n            };\n\n            auto rescale_rO = [&](float scales[2]) {\n                CUTE_UNROLL\n                for (int row = 0; row < 2; ++row) {\n                    CUTE_UNROLL\n                    for (int i = row*2; i < size(rO); i += 4) {\n                        rO(i) *= scales[row];\n                        rO(i+1) *= scales[row];\n                    }\n                    rL[row] *= scales[row];\n                }\n            };\n            \n            CUTE_NO_UNROLL\n            for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {\n                Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{});\n                Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{});\n\n                if (block_idx == 0) {\n                    // NOTE: We put this code here to avoid register spilling\n                    pipelined_wait_and_qkt_gemm_l();\n                    pipelined_wait_and_qkt_gemm_r();\n                    warpgroup_wait<0>();\n                }\n                \n                // Online softmax, inform WG1\n                mask_rP(Warpgroup0{});\n                \n                \n                online_softmax_and_rescale_o(Warpgroup0{});\n                NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready);\n\n                // Issue rO0 += rS0 @ sV0l\n                gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup);\n                warpgroup_commit_batch();\n\n                // Mark V0L as free\n                warpgroup_wait<0>();\n                plan.bar_k0_free[0].arrive();\n\n                // Wait for new sM, scale rS, save, inform WG1\n                NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready);\n  
              float new_rM[2], scale_factors[2];\n                *(float2*)new_rM = plan.sM[idx_in_warpgroup/4];\n                CUTE_UNROLL\n                for (int i = 0; i < 2; ++i) {\n                    scale_factors[i] = exp2f(rM[i] - new_rM[i]);\n                    rM[i] = new_rM[i];\n                }\n                scale_rS(scale_factors);\n                save_rS_to_sS(rS, sS0, idx_in_warpgroup);\n                fence_view_async_shared();\n                NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready);\n\n                // Wait for sS1\n                NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready);\n\n                // Rescale rO0, Issue rO0 += sS1 @ sV1L\n                rescale_rO(scale_factors);\n                gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup);\n                warpgroup_commit_batch();\n\n                cur_bar_wait_phase ^= 1;\n\n                if (block_idx+2 < num_topk_blocks) {\n                    // Launch the next QK^T GEMM\n                    pipelined_wait_and_qkt_gemm_l();\n\n                    // Mark V1L as free\n                    warpgroup_wait<1>();\n                    plan.bar_k1_free[0].arrive();\n                    pipelined_wait_and_qkt_gemm_r();\n\n                    // Wait for rP0 = sQ @ sK0\n                    warpgroup_wait<0>();\n                } else {\n                    // Mark V1L as free\n                    warpgroup_wait<0>();\n                    plan.bar_k1_free[0].arrive();\n                }\n            }\n\n            reduce_L();\n            store_O();\n        } else {\n            // Warpgroup 1\n\n            auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) {\n                plan.bar_k1_ready[1].wait(cur_bar_wait_phase);\n                qkt_gemm_one_tile(Warpgroup1{}, 4, true);\n                qkt_gemm_one_tile(Warpgroup1{}, 5, false);\n                qkt_gemm_one_tile(Warpgroup1{}, 6, false);\n     
           qkt_gemm_one_tile(Warpgroup1{}, 7, false);\n                if constexpr (D_QK == 576) {\n                    qkt_gemm_one_tile(Warpgroup1{}, 8, false);\n                }\n                plan.bar_k1_ready[0].wait(cur_bar_wait_phase);\n                qkt_gemm_one_tile(Warpgroup1{}, 0, false);\n                qkt_gemm_one_tile(Warpgroup1{}, 1, false);\n                qkt_gemm_one_tile(Warpgroup1{}, 2, false);\n                qkt_gemm_one_tile(Warpgroup1{}, 3, false);\n                warpgroup_commit_batch();\n            };\n            \n            CUTE_NO_UNROLL\n            for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {\n                Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{});\n                Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{});\n\n                // Issue rP1 = sQ @ sK1, and wait\n                pipelined_wait_and_qkt_gemm();\n                warpgroup_wait<0>();\n\n                mask_rP(Warpgroup1{});\n\n\n                // Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready)\n                NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready);\n                online_softmax_and_rescale_o(Warpgroup1{});\n                NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready);\n\n\n                // Issue rO1 += rS1 @ sV1R\n                gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup);\n                warpgroup_commit_batch();\n                \n                // Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R\n                save_rS_to_sS(rS, sS1, idx_in_warpgroup);   // Put it here is faster\n                NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready);\n                gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup);\n                warpgroup_commit_batch();\n                \n                // 
Save rS1, inform WG0\n                fence_view_async_shared();\n                NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready);\n\n                // Wait for GEMM, and inform that sV1R is free\n                warpgroup_wait<1>();\n                plan.bar_k1_free[1].arrive();\n\n                // Wait for GEMM, and inform that sV0R is free\n                warpgroup_wait<0>();\n                plan.bar_k0_free[1].arrive();\n\n                cur_bar_wait_phase ^= 1;\n            }\n\n            reduce_L();\n            store_O();\n\n            // Save lse\n            if (idx_in_warpgroup%4 == 0) {\n                for (int row = 0; row < 2; ++row) {\n                    int real_row = get_AorC_row_idx(row, idx_in_warpgroup);\n                    bool is_no_valid_tokens = rL[row] == 0.0f;\n                    plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]*CUDART_LN2_F;\n                    plan.final_lse[real_row] = is_no_valid_tokens ? +INFINITY : logf(rL[row]) + rM[row]*CUDART_LN2_F;\n                }\n                fence_view_async_shared();\n            }\n\n            NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync);\n            if (idx_in_warpgroup == 0) {\n                int g_offset = s_q_idx*params.h_q + q_h_idx*B_H;\n                SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float));\n                SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float));\n                cute::tma_store_arrive();\n            }\n        }\n    } else {\n        // Producer warpgroup\n        cutlass::arch::warpgroup_reg_dealloc<72>();\n\n        constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE;\n        constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS;\n        int idx_in_group = idx_in_warpgroup % GROUP_SIZE;\n        int group_idx = idx_in_warpgroup / GROUP_SIZE;\n        int* gIndices = params.indices + 
s_q_idx*params.stride_indices_s_q;   // [topk]\n\n        bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8));\n        bf16* my_gKV_base = params.kv + idx_in_group*8;\n        \n        int64_t token_indices[2][NUM_ROWS_PER_GROUP];\n        bool is_token_valid[2][NUM_ROWS_PER_GROUP];\n        auto load_token_indices = [&](int block_idx) {\n            CUTE_UNROLL\n            for (int buf_idx = 0; buf_idx < 2; ++buf_idx) {\n                CUTE_UNROLL\n                for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {\n                    int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx;\n                    int t = __ldg(gIndices + offs);\n                    token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv;   // We mult it with params.stride_kv_s_kv here since it's faster\n                    bool is_cur_token_valid = t >= 0 && t < params.s_kv;\n                    if constexpr (HAVE_TOPK_LENGTH) {\n                        is_cur_token_valid &= offs < topk_length;\n                    }\n                    is_token_valid[buf_idx][local_row] = is_cur_token_valid;\n                }\n            }\n        };\n        \n        int64_t cache_policy = createpolicy_evict_last();\n        auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) {\n            // Copy some K/V tiles from global memory to shared memory\n            // A tile has a shape of 64 (B_TOPK) x 64\n            // `buf_idx` is the index of the shared memory buffer, 0 or 1\n            // `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8\n            CUTE_UNROLL\n            for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {\n                int64_t token_index = token_indices[buf_idx][local_row];\n                CUTE_UNROLL\n                for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) 
{\n                    cp_async_cacheglobal_l2_prefetch_256B(\n                        my_gKV_base + token_index + tile_idx*64,\n                        my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64),\n                        is_token_valid[buf_idx][local_row],\n                        cache_policy\n                    );\n                }\n            }\n        };\n\n        auto commit_to_mbar = [&](transac_bar_t &bar) {\n            cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar));\n        };\n\n        int cur_bar_wait_phase = 1;\n\n        CUTE_NO_UNROLL\n        for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {\n            load_token_indices(block_idx);\n\n            // V0L\n            plan.bar_k0_free[0].wait(cur_bar_wait_phase);\n            copy_tiles(block_idx+0, 0, 0, 4);\n            commit_to_mbar(plan.bar_k0_ready[0]);\n\n            // V1R\n            plan.bar_k1_free[1].wait(cur_bar_wait_phase);\n            copy_tiles(block_idx+1, 1, 4, D_K/64);\n            commit_to_mbar(plan.bar_k1_ready[1]);\n            \n            // V0R\n            plan.bar_k0_free[1].wait(cur_bar_wait_phase);\n            copy_tiles(block_idx+0, 0, 4, D_K/64);\n            commit_to_mbar(plan.bar_k0_ready[1]);\n\n            // V1L\n            plan.bar_k1_free[0].wait(cur_bar_wait_phase);\n            copy_tiles(block_idx+1, 1, 0, 4);\n            commit_to_mbar(plan.bar_k1_ready[0]);\n\n            // Valid mask\n            // NOTE: V1R's finish implies maskings of the last round have finished\n            if (idx_in_group == 0) {\n                CUTE_UNROLL\n                for (int buf_idx = 0; buf_idx < 2; ++buf_idx)\n                    CUTE_UNROLL\n                    for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row)\n                        plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row];\n                
plan.bar_is_kv_valid_ready.arrive();\n            }\n\n            cur_bar_wait_phase ^= 1;\n        }\n    }\n\n\n#else\n    if (cute::thread0()) {\n        CUTE_INVALID_CONTROL_PATH(\"This kernel only supports sm90\");\n    }\n#endif\n}\n\ntemplate<typename Kernel, typename TMAParams>\n__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1)\nsparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TMAParams tma_params) {\n    Kernel::devfunc(params, tma_params);\n}\n\ntemplate<int D_QK, bool HAVE_TOPK_LENGTH>\nvoid KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &params) {\n    KU_ASSERT(params.h_kv == 1);\n    KU_ASSERT(params.topk % (2*B_TOPK) == 0);   // To save some boundary checks\n    KU_ASSERT(params.topk > 0);\n    KU_ASSERT(params.h_q % B_H == 0);\n\n    auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);\n    auto tma_Q = cute::make_tma_copy(\n        SM90_TMA_LOAD{},\n        make_tensor(\n            make_gmem_ptr((bf16*)params.q),\n            make_layout(\n                shape_Q,\n                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)\n            )\n        ),\n        SmemLayoutQ{}\n    );\n\n    CUtensorMap tensor_map_O;\n    {\n        uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q};\n        uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)};\n        uint32_t box_size[3] = {64, B_H, 1};\n        uint32_t elem_stride[3] = {1, 1, 1};\n        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(\n            &tensor_map_O,\n            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,\n            3,\n            params.out,\n            size,\n            stride,\n            box_size,\n            elem_stride,\n            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,\n            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,\n            
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,\n            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE\n        );\n        KU_ASSERT(res == CUresult::CUDA_SUCCESS);\n    }\n\n    TmaParams<\n        decltype(shape_Q), decltype(tma_Q)\n    > tma_params = {\n        shape_Q, tma_Q,\n        tensor_map_O\n    };\n    auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>, decltype(tma_params)>;\n\n    constexpr size_t smem_size = sizeof(SharedMemoryPlan);\n    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n\n    cutlass::ClusterLaunchParams launch_params = {\n        dim3((params.h_q/B_H)*params.s_q, 1, 1),    // NOTE: We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z)\n        dim3(NUM_THREADS, 1, 1),\n        dim3(1, 1, 1),\n        smem_size,\n        params.stream\n    }; \n    cutlass::launch_kernel_on_cluster(\n        launch_params, (void*)kernel, params, tma_params\n    );\n    KU_CHECK_KERNEL_LAUNCH();\n}\n\ntemplate<int D_QK, bool HAVE_TOPK_LENGTH>\nvoid run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {\n    KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(params);\n}\n\n}\n"
  },
  {
    "path": "csrc/sm90/prefill/sparse/phase1.h",
    "content": "#pragma once\n\n#include \"../../../params.h\"\n\nnamespace sm90::fwd {\n\ntemplate<int D_QK, bool HAVE_TOPK_LENGTH>\nvoid run_fwd_phase1_kernel(const SparseAttnFwdParams& params);\n\n}\n"
  },
  {
    "path": "csrc/smxx/decode/combine/combine.cu",
    "content": "#include \"combine.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n\n#include <kerutils/kerutils.cuh>\n\n#include \"params.h\"\n#include \"utils.h\"\n\nusing namespace cute;\n\nnamespace smxx::decode {\n\ntemplate<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>\n__global__ void __launch_bounds__(NUM_THREADS)\nflash_fwd_mla_combine_kernel(__grid_constant__ const CombineParams params) {\n    // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]\n    // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result\n    static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m\n    const int batch_idx = blockIdx.x;\n    const int s_q_idx = blockIdx.y;\n    const int h_block_idx = blockIdx.z;\n    const int warp_idx = threadIdx.x / 32;\n    const int lane_idx = threadIdx.x % 32;\n\n    int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);\n    if (warp_idx >= num_valid_heads) {\n        return;\n    }\n\n    const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);\n    const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);\n    const int my_num_splits = end_split_idx - start_split_idx;\n    if (my_num_splits == 1) {\n        return;\n    }\n    \n    FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);\n    \n    Tensor gLseAccum = make_tensor(\n        make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),\n        Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},\n        make_stride(params.stride_lse_accum_split, _1{})\n    );\n    Tensor gLse = make_tensor(\n        make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + 
h_block_idx*BLOCK_SIZE_M),\n        Shape<Int<BLOCK_SIZE_M>>{},\n        Stride<_1>{}\n    );\n    \n    __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];\n\n    // Wait for the previous kernel (the MLA kernel) to finish\n    cudaGridDependencySynchronize();\n\n    // Prefetch\n    static_assert(HEAD_DIM_V % (32*4) == 0);\n    constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4);\n    float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;\n    float4 datas[ELEMS_PER_THREAD];\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < ELEMS_PER_THREAD; ++i) {\n        datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL\n    }\n\n    // Warp #i gathers LseAccum for seq #i\n    {\n        constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);\n        float local_lse[NUM_LSE_PER_THREAD];\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {\n            const int split_idx = i*32 + lane_idx;\n            local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;\n        }\n\n        float max_lse = -INFINITY;\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)\n            max_lse = max(max_lse, local_lse[i]);\n        CUTLASS_PRAGMA_UNROLL\n        for (int offset = 16; offset >= 1; offset /= 2)\n            max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));\n        max_lse = max_lse == -INFINITY ? 
0.0f : max_lse;  // In case all local LSEs are -inf\n\n        float sum_lse = 0;\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)\n            sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);\n        CUTLASS_PRAGMA_UNROLL\n        for (int offset = 16; offset >= 1; offset /= 2)\n            sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);\n\n        float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;\n        if (lane_idx == 0)\n            gLse(warp_idx) = global_lse / (float)M_LOG2E;\n        \n        if (params.attn_sink != nullptr) {\n            int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;\n            float attn_sink = __ldg(params.attn_sink + q_head_idx);\n            if (global_lse != INFINITY) {\n                // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)\n                // If attn_sink is -inf, this has no effect on global_lse\n                global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));\n            } else {\n                // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)\n                global_lse = attn_sink == -INFINITY ? 
+INFINITY : attn_sink*CUDART_L2E_F;\n            }\n        }\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {\n            const int split_idx = i*32 + lane_idx;\n            smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);\n        }\n    }\n\n    __syncwarp();\n\n    // Warp #i accumulates activation for seq #i\n    {\n        float4 result[ELEMS_PER_THREAD];\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < ELEMS_PER_THREAD; ++i)\n            result[i] = {0.0f, 0.0f, 0.0f, 0.0f};\n\n        #pragma unroll 1\n        for (int split = 0; split < my_num_splits; ++split) {\n            float lse_scale = smem_buf[warp_idx][split];\n            // if (lse_scale != 0.f) {\n            CUTLASS_PRAGMA_UNROLL\n            for (int i = 0; i < ELEMS_PER_THREAD; ++i) {\n                result[i].x += lse_scale * datas[i].x;\n                result[i].y += lse_scale * datas[i].y;\n                result[i].z += lse_scale * datas[i].z;\n                result[i].w += lse_scale * datas[i].w;\n                if (split != my_num_splits-1) {\n                    datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128);\n                }\n            }\n            // }\n        }\n        \n        const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;\n        ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;\n\n        CUTLASS_PRAGMA_UNROLL\n        for (int i = 0; i < ELEMS_PER_THREAD; ++i) {\n            float4 data = result[i];\n            ElementT data_converted[4];\n            data_converted[0] = (ElementT)(data.x);\n            data_converted[1] = (ElementT)(data.y);\n            data_converted[2] = (ElementT)(data.z);\n            data_converted[3] = (ElementT)(data.w);\n            static_assert(sizeof(ElementT) == 2);\n            *(uint64_t*)(o_ptr + lane_idx*4 + i*128) = 
*(uint64_t*)data_converted;\n        }\n    }\n}\n\n\n#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...)       \\\n    [&] {                                                  \\\n        if (NUM_SPLITS <= 32) {                            \\\n            constexpr static int NAME = 32;                \\\n            return __VA_ARGS__();                          \\\n        } else if (NUM_SPLITS <= 64) {                     \\\n            constexpr static int NAME = 64;                \\\n            return __VA_ARGS__();                          \\\n        } else if (NUM_SPLITS <= 96) {                     \\\n            constexpr static int NAME = 96;                \\\n            return __VA_ARGS__();                          \\\n        } else if (NUM_SPLITS <= 128) {                    \\\n            constexpr static int NAME = 128;               \\\n            return __VA_ARGS__();                          \\\n        } else if (NUM_SPLITS <= 160) {                    \\\n            constexpr static int NAME = 160;               \\\n            return __VA_ARGS__();                          \\\n        } else {                                           \\\n            FLASH_ASSERT(false);                           \\\n        }                                                  \\\n    }()\n\n\ntemplate<typename ElementT>\nvoid run_flash_mla_combine_kernel(CombineParams &params) {\n    static constexpr int HEAD_DIM_V = 512;  // Since only this head dimension is supported by Flash MLA\n    FLASH_ASSERT(params.d_v == HEAD_DIM_V);\n    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {\n        constexpr int BLOCK_SIZE_M = 8;\n        constexpr int NUM_THREADS = BLOCK_SIZE_M*32;\n        constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);\n        auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;\n        CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n        // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)\n        cudaLaunchAttribute attribute[1];\n        attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;\n        attribute[0].val.programmaticStreamSerializationAllowed = 1;\n        cudaLaunchConfig_t combine_kernel_config = {\n            dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),\n            dim3(NUM_THREADS, 1, 1),\n            0,\n            params.stream,\n            attribute,\n            1\n        };\n        CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params));\n    });\n    CHECK_CUDA_KERNEL_LAUNCH();\n}\n\ntemplate void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(CombineParams &params);\n\n#ifndef FLASH_MLA_DISABLE_FP16\ntemplate void run_flash_mla_combine_kernel<cutlass::half_t>(CombineParams &params);\n#endif\n\n}\n"
  },
  {
    "path": "csrc/smxx/decode/combine/combine.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace smxx::decode {\n\ntemplate<typename ElementT>\nvoid run_flash_mla_combine_kernel(CombineParams &params);\n\n}\n"
  },
  {
    "path": "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
    "content": "#include \"get_decoding_sched_meta.h\"\n\n#include <cuda_runtime_api.h>\n#include <cutlass/fast_math.h>\n#include <kerutils/kerutils.cuh>\n\n#include \"utils.h\"\n\nnamespace smxx::decode {\n\n__global__ void __launch_bounds__(32, 1, 1)\nget_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) {\n    int *seqlens_k_ptr = params.seqlens_k_ptr;\n    DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;\n    int *num_splits_ptr = params.num_splits_ptr;\n    int batch_size = params.b;\n    int block_size_n = params.block_size_n;\n    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;\n    int num_sm_parts = params.num_sm_parts;\n\n    extern __shared__ int shared_mem[];\n    int* num_blocks_shared = shared_mem; // [batch_size]\n    int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]\n    int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size]\n    int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size]\n    int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size]\n\n    int total_num_blocks = 0;\n    for (int i = threadIdx.x; i < batch_size; i += 32) {\n        int cur_s_k;\n        if (params.topk == -1) {\n            // Dense model, cur_s_k = actual s_k\n            cur_s_k = __ldg(seqlens_k_ptr + i);\n        } else {\n            // Sparse model, cur_s_k = topk (+ extra topk)\n            cur_s_k = params.topk_length ? __ldg(params.topk_length + i) : params.topk;\n            if (cur_s_k == 0) cur_s_k = 1;  // Ensure the main loop will never be empty\n            if (params.extra_topk) {\n                cur_s_k = ku::ceil(cur_s_k, block_size_n);\n                cur_s_k += params.extra_topk_length ? 
__ldg(params.extra_topk_length + i) : params.extra_topk;\n            }\n        }\n        seqlens_k_shared[i] = cur_s_k;\n        int first_token_idx = 0;\n        int last_token_idx = max(cur_s_k-1, 0);\n        int cur_first_block_idx = first_token_idx / block_size_n;\n        int cur_last_block_idx = last_token_idx / block_size_n;\n        // NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx]\n        // NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel.\n        int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;\n        total_num_blocks += num_blocks + fixed_overhead_num_blocks;\n        num_blocks_shared[i] = num_blocks;\n        first_block_idx_shared[i] = cur_first_block_idx;\n        last_block_idx_shared[i] = cur_last_block_idx;\n    }\n    for (int offset = 16; offset >= 1; offset /= 2) {\n        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);\n    }\n    __syncwarp();\n\n    if (threadIdx.x == 0) {\n        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;\n\n        int now_req_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;\n        num_splits_shared[0] = 0;\n        for (int i = 0; i < num_sm_parts; ++i) {\n            DecodingSchedMeta cur_meta;\n            cur_meta.begin_req_idx = now_req_idx;\n            cur_meta.begin_block_idx = now_block + first_block_idx_shared[now_req_idx];\n            cur_meta.begin_split_idx = now_n_split_idx;\n            cur_meta.is_first_req_splitted = (now_block != 0);\n            int remain_payload = payload;\n            while (now_req_idx < batch_size) {\n                int num_blocks = num_blocks_shared[now_req_idx];\n                int now_remain_blocks = num_blocks - now_block;\n                if 
(remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {\n                    cum_num_splits += now_n_split_idx + 1;\n                    num_splits_shared[now_req_idx + 1] = cum_num_splits;\n                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;\n                    ++now_req_idx;\n                    now_block = 0;\n                    now_n_split_idx = 0;\n                } else {\n                    if (remain_payload - fixed_overhead_num_blocks > 0) {\n                        now_block += remain_payload - fixed_overhead_num_blocks;\n                        ++now_n_split_idx;\n                        remain_payload = 0;\n                    }\n                    break;\n                }\n            }\n            cur_meta.end_req_idx = now_block > 0 ? now_req_idx : now_req_idx - 1;\n            cur_meta.end_block_idx = now_block > 0 ? now_block + first_block_idx_shared[now_req_idx] : (seqlens_k_shared[now_req_idx-1] == 0 ? 0 : last_block_idx_shared[now_req_idx-1] + 1);\n            cur_meta.is_last_req_splitted = cur_meta.end_block_idx != last_block_idx_shared[cur_meta.end_req_idx] + 1 && seqlens_k_shared[cur_meta.end_req_idx] != 0;\n            if (cur_meta.begin_req_idx == cur_meta.end_req_idx) {\n                cur_meta.is_first_req_splitted = cur_meta.is_last_req_splitted = cur_meta.is_first_req_splitted || cur_meta.is_last_req_splitted;\n            }\n            tile_scheduler_metadata_ptr[i] = cur_meta;\n        }\n        FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0);\n    }\n    __syncwarp();\n\n    for (int i = threadIdx.x; i <= batch_size; i += 32) {\n        num_splits_ptr[i] = num_splits_shared[i];\n    }\n}\n\nvoid run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params) {\n    int smem_size = sizeof(int) * (params.b*5+1);\n    CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    
get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params);\n    CHECK_CUDA_KERNEL_LAUNCH();\n}\n\n}\n"
  },
  {
    "path": "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h",
    "content": "#pragma once\n\n#include \"params.h\"\n\nnamespace smxx::decode {\n\nvoid run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params);\n\n}\n"
  },
  {
    "path": "csrc/utils.h",
    "content": "#pragma once\n\n#include <cstdint>\n\n#define CHECK_CUDA(call)                                                                                  \\\n    do {                                                                                                  \\\n        cudaError_t status_ = call;                                                                       \\\n        if (status_ != cudaSuccess) {                                                                     \\\n            fprintf(stderr, \"CUDA error (%s:%d): %s\\n\", __FILE__, __LINE__, cudaGetErrorString(status_)); \\\n            exit(1);                                                                              \\\n        }                                                                                                 \\\n    } while(0)\n\n#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())\n\n\n#define FLASH_ASSERT(cond)                                                                                \\\n    do {                                                                                                  \\\n        if (not (cond)) {                                                                                 \\\n            fprintf(stderr, \"Assertion failed (%s:%d): %s\\n\", __FILE__, __LINE__, #cond);                 \\\n            exit(1);                                                                                      \\\n        }                                                                                                 \\\n    } while(0)\n\n\n#define FLASH_DEVICE_ASSERT(cond)                                                                         \\\n    do {                                                                                                  \\\n        if (not (cond)) {                                                                                 \\\n            printf(\"Assertion failed (%s:%d): %s\\n\", __FILE__, __LINE__, 
#cond);                          \\\n            asm(\"trap;\");                                                                                 \\\n        }                                                                                                 \\\n    } while(0)\n\n#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print(\"\\n\"); }\n\ntemplate<typename T>\n__inline__ __host__ __device__ T ceil_div(const T &a, const T &b) {\n    return (a + b - 1) / b;\n}\n\n#ifndef TRAP_ONLY_DEVICE_ASSERT\n#define TRAP_ONLY_DEVICE_ASSERT(cond) \\\ndo { \\\n    if (not (cond)) \\\n        asm(\"trap;\"); \\\n} while (0)\n#endif\n\n\nstruct RingBufferState {\n    uint32_t cur_block_idx = 0u;\n\n    __device__ __forceinline__\n    void update() {\n        cur_block_idx += 1;\n    }\n\n    template<uint32_t NUM_STAGES>\n    __device__ __forceinline__\n    std::pair<uint32_t, bool> get() const {\n        uint32_t stage_idx = cur_block_idx % NUM_STAGES;\n        bool phase = (cur_block_idx / NUM_STAGES) & 1;\n        return {stage_idx, phase};\n    }\n\n    __device__ __forceinline__\n    RingBufferState offset_by(const int offset) const {\n        // Must guarantee no underflow\n        uint32_t new_block_idx = static_cast<uint32_t>(static_cast<int>(cur_block_idx) + offset);\n        RingBufferState new_state;\n        new_state.cur_block_idx = new_block_idx;\n        return new_state;\n    }\n};\n"
  },
  {
    "path": "docs/20250422-new-kernel-deep-dive.md",
    "content": "# A Deep-Dive Into the New Flash MLA Kernel\n\nIn the [previous version](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) of the Flash MLA kernel, we have achieved impressive performance: 3000 GB/s in memory-intensive settings and 580 TFlops in compute-bound settings. Now, we're pushing these numbers even further, reaching up to 660 TFlops.\n\nIn this blog, we present a deep dive into the new kernel, explaining the optimizations and techniques behind this performance boost. We'll first explain why the MLA kernel is compute-bound despite being a decoding-stage attention kernel, then discuss our high-level kernel schedule design, and finally cover the technical details of the new kernel.\n\n## A Theoretical Analysis of the MLA Algorithm\n\nGPU kernels can be classified as either compute-bound (limited by floating-point operations per second, FLOPs) or memory-bound (limited by memory bandwidth). To identify the kernel's bottleneck, we calculate the ratio of FLOPs to memory bandwidth (FLOPs/byte) and compare it with the GPU's capacity.\n\nAssume the number of q heads is $h_q$, the number of q tokens per request is $s_q$ (should be 1 if MTP / speculative decoding is disabled), the number of kv tokens per request is $s_k\\ (s_k \\gg h_q s_q)$, and the head dimensions of K and V are $d_k$ and $d_v$ respectively. The number of FLOPs is roughly $2 (h_q s_q \\cdot d_k \\cdot s_k + h_q s_q \\cdot s_k \\cdot d_v) = 2 h_q s_q s_k (d_k+d_v)$, and the memory access volume (in bytes) is $\\mathop{\\text{sizeof}}(\\text{bfloat16}) \\times (h_q s_q d_k + s_k d_k + h_q s_q d_v) \\approx 2s_k d_k$. Thus, the compute-memory ratio is $h_q s_q \\cdot \\frac{d_k+d_v}{d_k} \\approx 2 h_q s_q$.\n\nAn NVIDIA H800 SXM5 GPU has a peak memory bandwidth of 3.35 TB/s and peak FLOPs of 990 TFlops. However, due to throttling (reducing to ~1600 MHz in our case), the practical peak FLOPs drops to ~865 TFlops. 
Therefore, when $h_qs_q \\ge \\frac{1}{2} \\cdot \\frac{865}{3.35} = 128$, the kernel is compute-bound; otherwise, it's memory-bound.\n\nAccording to [the overview of DeepSeek's Online Inference System](https://github.com/deepseek-ai/open-infra-index/blob/main/202502OpenSourceWeek/day_6_one_more_thing_deepseekV3R1_inference_system_overview.md), we don't use Tensor Parallel for decoding instances, meaning $h_q$ is 128 and the kernel is compute-bound. Thus, we need to optimize the kernel for compute-bound settings.\n\n## High-Level Design of the New Kernel\n\nTo fully utilize GPU compute resources, we need to overlap CUDA Core operations with Tensor Core operations and memory access with computation, keeping the Tensor Core constantly busy. This requires redesigning the kernel's \"schedule.\"\n\n[FlashAttention-3's paper](https://arxiv.org/abs/2407.08608) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \\times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, two output matrices would consume the entire register file, so we can store only one output matrix per SM. This eliminates the possibility of having two output matrices and letting them use CUDA Core and Tensor Core in an interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation.\n\n(You might pause here to ponder - perhaps you can find a better solution than ours!)\n\nOur solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. 
In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \\times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \\times 256$). The output matrix is then computed as follows:\n\n0. Maintain a running max $m$ (initialized to $-\\infty$, shared between the two warpgroups) and output matrices $\\vec o_L, \\vec o_R$ (initialized to 0).\n1. [0] Compute $`\\vec p_0 = \\vec q K_0^\\intercal / qk\\_scale`$.\n2. [1] Compute $`\\vec p_1 = \\vec q K_1^\\intercal / qk\\_scale`$.\n3. [0] Compute $mp_0 = \\max(\\vec p_0)$, $`m\\_new_0 = \\max(m, mp_0)`$, and $`scale_0 = \\exp(m\\_new_0 - m)`$. Update $`m \\gets m\\_new_0`$.\n4. [0] Perform softmax on $\\vec p_0$: $`\\vec p_0 \\gets \\exp(\\vec p_0 - m\\_new_0)`$.\n5. [0] Update $\\vec o_L \\gets \\vec o_L \\cdot scale_0 + \\vec p_0 V_{0L}$.\n6. [1] Compute $mp_1 = \\max(\\vec p_1)$, $`m\\_new_1 = \\max(m, mp_1)`$, and $`scale_1 = \\exp(m\\_new_1 - m)`$. Update $`m \\gets m\\_new_1`$.\n7. [1] Perform softmax on $\\vec p_1$: $`\\vec p_1 \\gets \\exp(\\vec p_1 - m\\_new_1)`$.\n8. [1] Update $\\vec o_R \\gets \\vec o_R \\cdot (scale_0 \\cdot scale_1) + \\vec p_1 V_{1R}$.\n9. [0] Update $\\vec p_0 \\gets \\vec p_0 \\cdot scale_1$.\n10. [1] Update $\\vec o_R \\gets \\vec o_R + \\vec p_0 V_{0R}$.\n11. [0] Update $\\vec o_L \\gets \\vec o_L \\cdot scale_1 + \\vec p_1 V_{1L}$.\n\nNote: We assume one q head for simplicity, so $\\vec q$ and $\\vec o$ are vectors. Bracketed numbers indicate the warpgroup performing the operation. Assume $\\vec o_L$ resides in warpgroup 0's register and $\\vec o_R$ resides in warpgroup 1's register.\n\nThis schedule can be viewed as a \"ping-pong\" variant using one output matrix—we call it \"seesaw\" scheduling. It's mathematically equivalent to FlashAttention's online softmax algorithm. 
This schedule allows us to overlap CUDA Core and Tensor Core operations by interleaving the two warpgroups, and also allows us to overlap memory access with computation since we can launch the corresponding Tensor Memory Accelerator (TMA) instructions right after data is no longer needed.\n\nThe complete schedule is shown below (remember that in MLA, $K$ and $V$ are the same with different names):\n\n![MLA Kernel Sched](assets/MLA%20Kernel%20Sched.drawio.svg)\n\n## Discussion of Technical Details\n\nThis section covers technical details of the new kernel.\n\nFirst, although the kernel targets compute-bound scenarios (where memory bandwidth isn't the bottleneck), we can't ignore memory latency. If the data is not ready when we want to use it, we have to wait. To solve this problem, we employ the following techniques:\n\n- **Fine-grained TMA copy - GEMM pipelining:** For a $64 \\times 576$ K block, we launch 9 TMA copies (each moving a $64 \\times 64$ block). GEMM operations begin as soon as each TMA copy completes (When the first TMA copy is done, we can start the first GEMM operation, and so on), improving memory latency tolerance.\n- **Cache hints:** Using `cute::TMA::CacheHintSm90::EVICT_FIRST` for TMA copies improves L2 cache hit rates, as shown by experiments.\n\nThese optimizations achieve up to 80% Tensor Core utilization (of the throttled theoretical peak) and 3 TB/s memory bandwidth on an H800 SXM5 GPU. While slightly slower (~2%) than the old ping-pong buffer version in memory-bound settings, this is acceptable.\n\nOther performance improvements include:\n- **Programmatic Dependent Launch.** We use [programmatic dependent launch](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization) to overlap `splitkv_mla` and `combine` kernels.\n- **Tile Scheduler.** We implement a tile scheduler to allocate jobs (requests and blocks) to SMs. 
This ensures a balanced load across SMs.\n\n## Acknowledgements\n\nFlashMLA's algorithm and scheduling are inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work.\n\n## Citation\n\n```bibtex\n@misc{flashmla2025,\n      title={FlashMLA: Efficient MLA decoding kernels},\n      author={Jiashi Li, Shengyu Liu},\n      year={2025},\n      publisher = {GitHub},\n      howpublished = {\\url{https://github.com/deepseek-ai/FlashMLA}},\n}\n```\n"
  },
  {
    "path": "docs/20250929-hopper-fp8-sparse-deep-dive.md",
    "content": "# A Deep Dive Into The Flash MLA FP8 Decoding Kernel on Hopper\n\nWith the release of DeepSeek-V3.2, we have doubled the context length of our models from 64K tokens to 128K tokens. This puts significant pressure on GPU memory (a single request with 128K tokens requires a KVCache of size $576 \\times 2 \\times 62 \\times 128 \\times 1024 = 8.72\\ \\mathrm{GiB}$), which can lead to out-of-memory (OOM) errors or under-utilized GPUs due to small batch sizes. To address this, we introduced FP8 KVCache for DeepSeek-V3.2.\n\nHowever, writing a high-performance decoding kernel is challenging due to the need for dequantization and its sparse memory access patterns. In this blog, we share the story behind our new FP8 sparse decoding kernel for Hopper GPUs. We will first explain our FP8 KVCache format, then provide a theoretical analysis of clock cycles, and finally detail the techniques used in our new kernel.\n\n## The FP8 KVCache Format\n\nRecall that the decoding phase of the Multi-head Latent Attention (MLA) algorithm operates similarly to Multi-Query Attention (MQA), with 128 query heads and 1 key head, where `head_dim_k = 576` and `head_dim_v = 512` respectively. To reduce the size of the KVCache while maintaining accuracy, we use a fine-grained quantization method. Specifically, we apply tile-level quantization (with a tile size of $1 \\times 128$) to the first 512 elements in each token's KV Cache. This results in 512 `float8_e4m3` values and 4 `float32` scale factors. For the remaining 64 elements (the RoPE part), we do not apply quantization as they are sensitive to precision loss. Therefore, in GPU memory, each token's KVCache occupies 656 bytes, consisting of 512 `float8_e4m3`s, 4 `float32`s, and 64 `bfloat16`s.\n\nInside the kernel, we first dequantize the 512 `float8_e4m3` values into 512 `bfloat16`s. We then concatenate them with the 64 original `bfloat16` values from the RoPE part. 
Finally, we perform the MQA calculation using matrix multiply-accumulate (MMA) operations in `bfloat16` precision (i.e., the inputs to the MMAs are in `bfloat16` and the outputs are in `float32`; this applies to both the QK GEMM and the attention-score-V GEMM).\n\n## Theoretical Analysis of Clock Cycles\n\nThe main challenge is that Tensor Cores (which handle MMA calculations) are extremely fast, while the dequantization process, performed on CUDA Cores, struggles to keep up.\n\nThe basic unit on an NVIDIA GPU is the Streaming Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). In our kernel, each CTA (cooperative thread array, i.e., a CUDA thread block) runs on one SM, and each SM is only mapped to one CTA. If we assign each CTA to process 64 query heads, it only requires $64 \\times (576+512) \\times 2 / 4096 \\approx 34$ cycles for MMA operations per K/V token.\n\nHowever, because the H800 cannot directly cast `float8_e4m3` to `bfloat16`, dequantizing the KVCache for one token requires the following steps:\n1.  Convert `float8_e4m3` to `half`\n2.  Convert `half` to `float32`\n3.  Convert `float32` to `bfloat16`\n4.  Multiply the converted `bfloat16` value by the `float32` scale factor\n\nAccording to [NVIDIA's documentation](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#throughput-of-native-arithmetic-instructions), we need at least $(\\frac{1}{64} + \\frac{1}{64} + \\frac{1}{16} + \\frac{1}{256}) \\times 512 \\approx 50$ cycles for dequantizing each token! This is significantly more than the 34 cycles required for the MMA operations, meaning the kernel is **dequantization-bound**. 
If left unaddressed, dequantization would become the performance bottleneck, leaving the powerful Tensor Cores underutilized.\n\n## Crossover\n\nBefore we continue, it's important to note a key fact: every query head within the same query token attends to the same key heads, because this is Multi-Query Attention (MQA).\n\nRecall that each CTA processes 64 query heads, while DeepSeek-V3.2 has a total of 128 query heads. If we can find a way to \"share\" the dequantized K/V values between two CTAs that are processing different sets of query heads, then each CTA would only need to dequantize **half** of the KV cache – which is fantastic! We call this method \"crossover\", since the idea was actually inspired by [Chromosomal crossover](https://en.wikipedia.org/wiki/Chromosomal_crossover) during [Meiosis](https://en.wikipedia.org/wiki/Meiosis).\n\nThe next question is, how do we implement this in CUDA? Before NVIDIA's Hopper architecture, the only options for data exchange between CTAs were global memory or the L2 cache, which are slow. However, the powerful Distributed Shared Memory gave us a new solution.\n\n## Distributed Shared Memory to the Rescue\n\nDistributed Shared Memory (DSM) is a new feature introduced with the Hopper architecture, alongside the CTA Cluster (thread block cluster). CTAs within the same cluster can directly access each other's shared memory. For more details, you can refer to [NVIDIA Hopper Architecture In-Depth](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/).\n\nHere is how we use it: We launch CTAs in clusters of size 2. Each CTA within a cluster is responsible for 64 query heads from the same query token. Each CTA performs the following steps:\n1.  Loads *half* of the quantized K/V from global memory. We use a wide `__ldg` load with a width of 128 bits to improve performance.\n2.  Dequantizes its assigned half on the CUDA Cores.\n3.  Stores the dequantized K/V into its own shared memory.\n4.  
Simultaneously uses `st.async` to write the dequantized K/V into the shared memory of the other CTA in the cluster.\n\nFor synchronization between these operations, we rely on the [cluster transaction barrier](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/), another powerful programming primitive available in CTA Clusters. After the data exchange is complete, each CTA has the *full* set of dequantized K and V values available in its own shared memory, which it can then use to perform the MMA operations.\n\n## Performance\nUsing these techniques, we achieved 410 TFLOPS in a compute-bound configuration (batch_size=128, num_heads=128, s_q=2, topk=2048) on H800 SXM5 GPUs. This is a significant improvement over the 250 TFLOPS achieved by our previous FP8 sparse decoding kernel without the crossover technique. \n\nAlthough this number is still below the 640 TFLOPS peak of our previous bfloat16 dense decoding kernel, one reason is that it's a **sparse** kernel, and its topk is only 2048. With a smaller topk, the relative overhead of the kernel's prologue and epilogue becomes larger compared with dense decoding with long context length. If we set topk to a larger value, such as 32768, this kernel can achieve up to 460 TFLOPS.\n\nFrom another perspective, the execution time of this kernel in the configuration mentioned above is comparable to that of the dense decoding kernel when the sequence length is around 3000. When the sequence length exceeds 3000, the performance advantage of our new kernel becomes even more significant. This also highlights the effectiveness of our DeepSeek Sparse Attention algorithm.\n"
  },
  {
    "path": "flash_mla/__init__.py",
    "content": "__version__ = \"1.0.0\"\n\nfrom flash_mla.flash_mla_interface import (\n    get_mla_metadata,\n    flash_mla_with_kvcache,\n    flash_attn_varlen_func,\n    flash_attn_varlen_qkvpacked_func,\n    flash_attn_varlen_kvpacked_func,\n    flash_mla_sparse_fwd\n)\n\n__all__ = [\n    \"get_mla_metadata\",\n    \"flash_mla_with_kvcache\",\n    \"flash_attn_varlen_func\",\n    \"flash_attn_varlen_qkvpacked_func\",\n    \"flash_attn_varlen_kvpacked_func\",\n    \"flash_mla_sparse_fwd\"\n]\n"
  },
  {
    "path": "flash_mla/flash_mla_interface.py",
    "content": "from typing import Optional, Tuple\nimport dataclasses\n\nimport torch\n\nimport flash_mla.cuda as flash_mla_cuda\n\n@dataclasses.dataclass\nclass FlashMLASchedMeta:\n    \"\"\"\n    A class that stores the tile scheduler metadata of FlashMLA\n    \"\"\"\n\n    @dataclasses.dataclass\n    class Config:\n        b: int\n        s_q: int\n        h_q: int\n        page_block_size: int\n        h_k: int\n\n        causal: bool\n        is_fp8_kvcache: bool\n        topk: Optional[int]\n\n        extra_page_block_size: Optional[int]\n        extra_topk: Optional[int]\n\n    have_initialized: bool = False\n\n    config: Optional[Config] = None\n\n    tile_scheduler_metadata: Optional[torch.Tensor] = None   # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.\n    num_splits: Optional[torch.Tensor] = None                # (1), dtype torch.int32.\n\n\ndef get_mla_metadata(\n    *args,\n    **kwargs\n) -> Tuple[FlashMLASchedMeta, None]:\n    \"\"\"\n    Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache.\n\n    Arguments:\n        This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface.\n\n    Return:\n        A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. 
Only the first element is useful.\n    \"\"\"\n    return FlashMLASchedMeta(), None\n\n\ndef flash_mla_with_kvcache(\n    q: torch.Tensor,\n    k_cache: torch.Tensor,\n    block_table: Optional[torch.Tensor],\n    cache_seqlens: Optional[torch.Tensor],\n    head_dim_v: int,\n    tile_scheduler_metadata: FlashMLASchedMeta,\n    num_splits: None = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    is_fp8_kvcache: bool = False,\n    indices: Optional[torch.Tensor] = None,\n    attn_sink: Optional[torch.Tensor] = None,\n    extra_k_cache: Optional[torch.Tensor] = None,\n    extra_indices_in_kvcache: Optional[torch.Tensor] = None,\n    topk_length: Optional[torch.Tensor] = None,\n    extra_topk_length: Optional[torch.Tensor] = None\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Arguments:\n        q: (batch_size, seq_len_q, num_heads_q, head_dim).\n        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).\n                Different modes (including fp8/bf16, and sparsity) have different KV cache layouts. See comments below for details.\n                The KV cache must be contiguously valid for sparse attention on sm100. Here \"contiguously valid\" means that every byte, from the very beginning of the KV cache to the last byte in the KV cache, is a valid memory address to access (i.e. accessing it won't trigger an illegal memory access). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.\n        block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.\n        cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.\n        head_dim_v: Head_dim of v. Must be 512.\n        tile_scheduler_metadata (sched_meta): FlashMLASchedMeta, returned by get_mla_metadata. 
You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.\n        num_splits: must be None (kept for compatibility with the old interface).\n        softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim_k).\n        causal: bool. Whether to apply a causal attention mask. Only valid for dense attention.\n        is_fp8_kvcache: bool.\n        indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled.\n                    Note that indices[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t within the page block),\n                    where t is the k-th token of the j-th q-sequence in the i-th batch.\n        attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If provided, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Has no effect on the returned softmax_lse. +inf will cause the result to become 0.\n        extra_k_cache and extra_indices_in_kvcache: If provided, the kernel will attend to these extra tokens in addition to those in k_cache and indices. Their format requirements are the same as k_cache and indices respectively.\n        topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. 
Useful when the actual topk for different queries are different so that we can save some computation, compared to masking.\n    \n    For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:\n        head_dim should be 576 while head_dim_v should be 512.\n        In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as:\n            - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.\n            - First 512 bytes: The \"quantized NoPE\" part, containing 512 float8_e4m3 values.\n            - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.\n            - Last 128 bytes: The \"RoPE\" part, containing 64 bfloat16 values. This part is not quantized for accuracy.\n\n    Return:\n        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).\n        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.\n    \"\"\"\n    sched_meta = tile_scheduler_metadata\n    indices_in_kvcache = indices\n    assert isinstance(sched_meta, FlashMLASchedMeta), \"tile_scheduler_metadata must be of type FlashMLASchedMeta\"\n    assert num_splits is None, \"num_splits must be None\"\n\n    topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None\n    extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None\n    extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None\n    if softmax_scale is None:\n        softmax_scale = q.shape[-1] ** (-0.5)\n\n    if not sched_meta.have_initialized:\n        # Sanity check. We only perform sanity check during the first invocation to save CPU time.\n        if indices_in_kvcache is not None:\n            assert not causal, \"causal must be False when indices_in_kvcache is not None (i.e. 
sparse attention is enabled)\"\n            \n        # Initialize the tile scheduler metadata during the first invocation.\n        sched_meta.have_initialized = True\n        sched_meta.config = FlashMLASchedMeta.Config(\n            q.shape[0],\n            q.shape[1],\n            q.shape[2],\n            k_cache.shape[1],\n            k_cache.shape[2],\n\n            causal,\n            is_fp8_kvcache,\n            topk,\n\n            extra_k_page_block_size,\n            extra_topk,\n        )\n    else:\n        # Check whether the input arguments are consistent with sched_meta\n        helper_msg = \" Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta.\"\n        assert sched_meta.config is not None\n        assert sched_meta.config.b == q.shape[0], \"sched_meta.config.b must be equal to batch_size.\" + helper_msg\n        assert sched_meta.config.s_q == q.shape[1], \"sched_meta.config.s_q must be equal to seq_len_q.\" + helper_msg\n        assert sched_meta.config.h_q == q.shape[2], \"sched_meta.config.h_q must be equal to num_heads_q.\" + helper_msg\n        assert sched_meta.config.page_block_size == k_cache.shape[1], \"sched_meta.config.page_block_size must be equal to page_block_size.\" + helper_msg\n        assert sched_meta.config.h_k == k_cache.shape[2], \"sched_meta.config.h_k must be equal to num_heads_k.\" + helper_msg\n        assert sched_meta.config.causal == causal, \"sched_meta.config.causal must be equal to causal.\" + helper_msg\n        assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, \"sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache.\" + helper_msg\n        assert sched_meta.config.topk == topk, \"sched_meta.config.topk must be equal to the last dim of indices_in_kvcache.\" + helper_msg\n        assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, 
\"sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache.\" + helper_msg\n        assert sched_meta.config.extra_topk == extra_topk, \"sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache.\" + helper_msg\n\n    if topk is not None:\n        # Sparse attention\n        assert not causal, \"causal must be False when sparse attention is enabled\"\n        assert is_fp8_kvcache, \"is_fp8_kvcache must be True when sparse attention is enabled\"\n        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(\n            q, k_cache, indices_in_kvcache, topk_length, attn_sink,\n            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,\n            extra_k_cache, extra_indices_in_kvcache, extra_topk_length,\n            head_dim_v, softmax_scale\n        )\n    else:\n        # Dense attention\n        assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, \"indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used.\"\n        assert block_table is not None and cache_seqlens is not None, \"block_table and cache_seqlens must be provided when dense attention is used.\"\n        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(\n            q, k_cache, head_dim_v,\n            cache_seqlens, block_table,\n            softmax_scale, causal,\n            sched_meta.tile_scheduler_metadata, sched_meta.num_splits\n        )\n    sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata\n    sched_meta.num_splits = new_num_splits\n    return (out, lse)\n\n\ndef flash_mla_sparse_fwd(\n    q: torch.Tensor,\n    kv: torch.Tensor,\n    indices: torch.Tensor,\n    sm_scale: float,\n    d_v: int = 512,\n    attn_sink: 
Optional[torch.Tensor] = None,\n    topk_length: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Sparse attention prefill kernel\n\n    Args:\n        q: [s_q, h_q, d_qk], bfloat16\n        kv: [s_kv, h_kv, d_qk], bfloat16\n        indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv\n        sm_scale: float\n        d_v: The dimension of value vectors. Can only be 512\n        attn_sink: optional, [h_q], float32.\n            If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).\n            +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros).\n            This argument has no effect on lse and max_logits.\n        topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices).\n            In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation.\n\n    Returns:\n        (output, max_logits, lse)\n        Please refer to tests/ref.py for the precise definitions of these parameters.\n        - output: [s_q, h_q, d_v], bfloat16\n        - max_logits:  [s_q, h_q], float\n        - lse: [s_q, h_q], float, log-sum-exp of attention scores\n    \"\"\"\n    results = flash_mla_cuda.sparse_prefill_fwd(\n        q, kv, indices, sm_scale, d_v, attn_sink, topk_length\n    )\n    return results\n\n\ndef _flash_attn_varlen_forward(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_qo: torch.Tensor,\n    cu_seqlens_kv: torch.Tensor,\n    max_seqlen_qo: int,\n    max_seqlen_kv: int,\n    out: Optional[torch.Tensor] = 
None,\n    lse: Optional[torch.Tensor] = None,\n    causal: bool = False,\n    softmax_scale: Optional[float] = None,\n    is_varlen: bool = True,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    qo_total_len, num_qo_heads, head_dim_qk = q.shape\n    kv_total_len, num_kv_heads, head_dim_vo = v.shape\n\n    mask_mode_code = 1 if causal else 0\n    if softmax_scale is None:\n        softmax_scale = head_dim_qk ** (-0.5)\n\n    if out is None:\n        out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype)\n    if lse is None:\n        # Make lse contiguous on seqlen dim\n        lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T\n\n    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)\n    flash_mla_cuda.dense_prefill_fwd(\n        workspace_buffer,\n        q,\n        k,\n        v,\n        cu_seqlens_qo,\n        cu_seqlens_kv,\n        out,\n        lse,\n        mask_mode_code,\n        softmax_scale,\n        max_seqlen_qo,\n        max_seqlen_kv,\n        is_varlen,\n    )\n\n    return out, lse\n\n\ndef _flash_attn_varlen_backward(\n    do: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    lse: torch.Tensor,\n    cu_seqlens_qo: torch.Tensor,\n    cu_seqlens_kv: torch.Tensor,\n    max_seqlen_qo: int,\n    max_seqlen_kv: int,\n    dq: Optional[torch.Tensor] = None,\n    dk: Optional[torch.Tensor] = None,\n    dv: Optional[torch.Tensor] = None,\n    causal: bool = False,\n    softmax_scale: Optional[float] = None,\n    is_varlen: bool = True,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    qo_total_len, num_qo_heads, head_dim_qk = q.shape\n    kv_total_len, num_kv_heads, head_dim_vo = v.shape\n\n    # TODO: fix bwd GQA\n    if num_qo_heads != num_kv_heads:\n        raise ValueError(f\"SM100 bwd doesn't support GQA now. 
num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.\")\n\n    mask_mode_code = 1 if causal else 0\n    if softmax_scale is None:\n        softmax_scale = head_dim_qk ** (-0.5)\n\n    if dq is None:\n        dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype)\n    if dk is None:\n        dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype)\n    if dv is None:\n        dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype)\n\n    max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8\n    bs = cu_seqlens_qo.shape[0] - 1\n    workspace_bytes = 0\n    workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk  # dQ_acc\n    workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2  # sum_OdO and scaled_lse\n    if num_qo_heads != num_kv_heads:\n        workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo)  # dKV_acc\n    workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)\n    flash_mla_cuda.dense_prefill_bwd(\n        workspace_buffer,\n        do,\n        q,\n        k,\n        v,\n        out,\n        lse,\n        cu_seqlens_qo,\n        cu_seqlens_kv,\n        dq,\n        dk,\n        dv,\n        mask_mode_code,\n        softmax_scale,\n        max_seqlen_qo,\n        max_seqlen_kv,\n        is_varlen,\n    )\n\n    return dq, dk, dv\n\n\nclass FlashAttnVarlenFunc(torch.autograd.Function):\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens_qo: torch.Tensor,\n        cu_seqlens_kv: torch.Tensor,\n        max_seqlen_qo: int,\n        max_seqlen_kv: int,\n        causal: bool = False,\n        softmax_scale: Optional[float] = None,\n        is_varlen: bool = True,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        out, lse = _flash_attn_varlen_forward(\n            q, k, v,\n     
       cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,\n            causal=causal, softmax_scale=softmax_scale,\n            is_varlen=is_varlen,\n        )\n        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv)\n        ctx.max_seqlen_qo = max_seqlen_qo\n        ctx.max_seqlen_kv = max_seqlen_kv\n        ctx.causal = causal\n        ctx.softmax_scale = softmax_scale\n        ctx.is_varlen = is_varlen\n        return out, lse\n\n    def backward(\n        ctx,\n        do: torch.Tensor,\n        dlse: torch.Tensor,\n    ):\n        del dlse  # LSE doesn't support backward currently\n        q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors\n        dq, dk, dv = _flash_attn_varlen_backward(\n            do, q, k, v, out, lse,\n            cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv,\n            causal=ctx.causal, softmax_scale=ctx.softmax_scale,\n            is_varlen=ctx.is_varlen,\n        )\n        return dq, dk, dv, None, None, None, None, None, None, None\n\n\ndef flash_attn_varlen_func(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_qo: torch.Tensor,\n    cu_seqlens_kv: torch.Tensor,\n    max_seqlen_qo: int,\n    max_seqlen_kv: int,\n    dropout_p: float = 0.0,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    deterministic: bool = False,\n    is_varlen: bool = True,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert dropout_p == 0.0\n    assert not deterministic\n    return FlashAttnVarlenFunc.apply(\n        q, k, v,\n        cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,\n        causal, softmax_scale, is_varlen,\n    )\n\n\ndef flash_attn_varlen_qkvpacked_func(\n    qkv: torch.Tensor,\n    cu_seqlens: torch.Tensor,\n    max_seqlen: int,\n    head_dim_qk: int,\n    dropout_p: float = 0.0,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    deterministic: bool = False,\n    
is_varlen: bool = True,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert dropout_p == 0.0\n    assert not deterministic\n    return FlashAttnVarlenFunc.apply(\n        qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:],\n        cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,\n        causal, softmax_scale, is_varlen,\n    )\n\n\ndef flash_attn_varlen_kvpacked_func(\n    q: torch.Tensor,\n    kv: torch.Tensor,\n    cu_seqlens_qo: torch.Tensor,\n    cu_seqlens_kv: torch.Tensor,\n    max_seqlen_qo: int,\n    max_seqlen_kv: int,\n    head_dim_qk: int,\n    dropout_p: float = 0.0,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    deterministic: bool = False,\n    is_varlen: bool = True,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert dropout_p == 0.0\n    assert not deterministic\n    return FlashAttnVarlenFunc.apply(\n        q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:],\n        cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,\n        causal, softmax_scale, is_varlen,\n    )\n"
  },
  {
    "path": "setup.py",
    "content": "import os\nfrom pathlib import Path\nfrom datetime import datetime\nimport subprocess\n\nfrom setuptools import setup, find_packages\n\nfrom torch.utils.cpp_extension import (\n    BuildExtension,\n    CUDAExtension,\n    IS_WINDOWS,\n    CUDA_HOME\n)\n\n\ndef is_flag_set(flag: str) -> bool:\n    return os.getenv(flag, \"FALSE\").lower() in [\"true\", \"1\", \"y\", \"yes\"]\n\ndef get_features_args():\n    features_args = []\n    if is_flag_set(\"FLASH_MLA_DISABLE_FP16\"):\n        features_args.append(\"-DFLASH_MLA_DISABLE_FP16\")\n    return features_args\n\ndef get_arch_flags():\n    # Check NVCC Version\n    # NOTE The \"CUDA_HOME\" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py`\n    assert CUDA_HOME is not None, \"PyTorch must be compiled with CUDA support\"\n    nvcc_version = subprocess.check_output(\n        [os.path.join(CUDA_HOME, \"bin\", \"nvcc\"), '--version'], stderr=subprocess.STDOUT\n    ).decode('utf-8')\n    nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip()\n    major, minor = map(int, nvcc_version_number.split('.'))\n    print(f'Compiling using NVCC {major}.{minor}')\n\n    DISABLE_SM100 = is_flag_set(\"FLASH_MLA_DISABLE_SM100\")\n    DISABLE_SM90 = is_flag_set(\"FLASH_MLA_DISABLE_SM90\")\n    if major < 12 or (major == 12 and minor <= 8):\n        assert DISABLE_SM100, \"sm100 compilation for Flash MLA requires NVCC 12.9 or higher. 
Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment.\"    # TODO Implement this\n\n    arch_flags = []\n    if not DISABLE_SM100:\n        arch_flags.extend([\"-gencode\", \"arch=compute_100f,code=sm_100f\"])\n    if not DISABLE_SM90:\n        arch_flags.extend([\"-gencode\", \"arch=compute_90a,code=sm_90a\"])\n    return arch_flags\n\ndef get_nvcc_thread_args():\n    nvcc_threads = os.getenv(\"NVCC_THREADS\") or \"32\"\n    return [\"--threads\", nvcc_threads]\n\nsubprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"csrc/cutlass\"])\n\nthis_dir = os.path.dirname(os.path.abspath(__file__))\n\nif IS_WINDOWS:\n    cxx_args = [\"/O2\", \"/std:c++20\", \"/DNDEBUG\", \"/W0\"]\nelse:\n    cxx_args = [\"-O3\", \"-std=c++20\", \"-DNDEBUG\", \"-Wno-deprecated-declarations\"]\n\next_modules = []\next_modules.append(\n    CUDAExtension(\n        name=\"flash_mla.cuda\",\n        sources=[\n            # API\n            \"csrc/api/api.cpp\",\n\n            # Misc kernels for decoding\n            \"csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu\",\n            \"csrc/smxx/decode/combine/combine.cu\",\n\n            # sm90 dense decode\n            \"csrc/sm90/decode/dense/instantiations/fp16.cu\",\n            \"csrc/sm90/decode/dense/instantiations/bf16.cu\",\n\n            # sm90 sparse decode\n            \"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu\",\n            \"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu\",\n            \"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu\",\n            \"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu\",\n\n            # sm90 sparse prefill\n            \"csrc/sm90/prefill/sparse/fwd.cu\",\n            \"csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu\",\n            \"csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu\",\n            
\"csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu\",\n            \"csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu\",\n\n            # sm100 dense prefill & backward\n            \"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu\",\n            \"csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu\",\n\n            # sm100 sparse prefill\n            \"csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu\",\n            \"csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu\",\n            \"csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu\",\n            \"csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu\",\n            \"csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu\",\n\n            # sm100 sparse decode\n            \"csrc/sm100/decode/head64/instantiations/v32.cu\",\n            \"csrc/sm100/decode/head64/instantiations/model1.cu\",\n            \"csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu\",\n        ],\n        extra_compile_args={\n            \"cxx\": cxx_args + get_features_args(),\n            \"nvcc\": [\n                \"-O3\",\n                \"-std=c++20\",\n                \"-DNDEBUG\",\n                \"-D_USE_MATH_DEFINES\",\n                \"-Wno-deprecated-declarations\",\n                \"-U__CUDA_NO_HALF_OPERATORS__\",\n                \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                \"-U__CUDA_NO_HALF2_OPERATORS__\",\n                \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n                \"--expt-relaxed-constexpr\",\n                \"--expt-extended-lambda\",\n                \"--use_fast_math\",\n                \"--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use\",\n                \"-lineinfo\",\n                \"--source-in-ptx\",\n            ] + get_features_args() + 
get_arch_flags() + get_nvcc_thread_args(),\n        },\n        include_dirs=[\n            Path(this_dir) / \"csrc\",\n            Path(this_dir) / \"csrc\" / \"kerutils\" / \"include\",   # TODO Remove me\n            Path(this_dir) / \"csrc\" / \"sm90\",\n            Path(this_dir) / \"csrc\" / \"cutlass\" / \"include\",\n            Path(this_dir) / \"csrc\" / \"cutlass\" / \"tools\" / \"util\" / \"include\",\n        ],\n    )\n)\n\ntry:\n    cmd = ['git', 'rev-parse', '--short', 'HEAD']\n    rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()\nexcept Exception as _:\n    now = datetime.now()\n    date_time_str = now.strftime(\"%Y-%m-%d-%H-%M-%S\")\n    rev = '+' + date_time_str\n\n\nsetup(\n    name=\"flash_mla\",\n    version=\"1.0.0\" + rev,\n    packages=find_packages(include=['flash_mla']),\n    ext_modules=ext_modules,\n    cmdclass={\"build_ext\": BuildExtension},\n)\n"
  },
  {
    "path": "tests/kernelkit/.gitignore",
    "content": "build\n*.so\n*.egg-info/\n__pycache__/\ndist/\n/.vscode\n.cache\n/temp\n/profiles\n"
  },
  {
    "path": "tests/kernelkit/__init__.py",
    "content": "from . import bench\nfrom . import compare\nfrom . import generate\nfrom . import precision\nfrom . import utils\n\nfrom .bench import bench_kineto, bench_by_cuda_events\nfrom .compare import get_cos_diff, check_is_bitwise_equal, check_is_allclose, check_is_bitwise_equal_comparator, check_is_allclose_comparator\nfrom .generate import gen_non_contiguous_randn_tensor, gen_non_contiguous_tensor, non_contiguousify\nfrom .precision import LowPrecisionMode, is_low_precision_mode, optional_cast_to_bf16_and_cast_back\nfrom .utils import colors, cdiv, is_using_profiling_tools, set_random_seed, Counter\n"
  },
  {
    "path": "tests/kernelkit/bench.py",
    "content": "from typing import Tuple, List, Callable, Union, Dict, overload\nimport dataclasses\n\nimport torch\nimport triton\n\nfrom .utils import is_using_profiling_tools\n\nclass empty_suppress:\n    def __enter__(self):\n        return self\n\n    def __exit__(self, *_):\n        pass\n\n@triton.jit\ndef profiler_range_start_marker_kernel():\n    pass\n\ndef _run_profiler_range_start_marker_kernel():\n    profiler_range_start_marker_kernel[(1,)]()\n\n@dataclasses.dataclass\nclass BenchKinetoRawResult:\n    \"\"\"\n    A struct holding the result of `bench_kineto`\n    \"\"\"\n\n    is_using_nsys: bool\n    num_tests: int\n    time_ranges: Dict[str, List[Tuple[float, float]]]\n\n    def _get_matched_kernel_name(self, name_substr: str, allow_no_match: bool = False, allow_multiple_match: bool = False) -> List[str]:\n        matched_names = [name for name in self.time_ranges.keys() if name_substr in name]\n        if not allow_no_match and len(matched_names) == 0:\n            all_kernel_names_str = '\\n  - ' + '\\n  - '.join(self.time_ranges.keys())\n            raise ValueError(f\"Error: No kernel name matched for substring {name_substr}.\\nAvailable kernels are: {all_kernel_names_str}\")\n        if not allow_multiple_match and len(matched_names) > 1:\n            raise ValueError(f\"Error: Multiple kernel matched for substring {name_substr}: {', '.join(matched_names)}\")\n        return matched_names\n    \n    def get_kernel_names(self) -> List[str]:\n        return list(self.time_ranges.keys())\n    \n    def get_kernel_times(self, kernel_names_substr: List[str], allow_indivisible_run_count: bool = False, allow_missing: bool = False, allow_multiple_match: bool = False, return_avg_individual_run: bool = False) -> List[float]:\n        \"\"\"\n        Get the average each-run time usage of each kernel provided in `kernel_names`\n\n        If return_avg_individual_run is False, return sum(time) / num_tests, else return sum(time) / len(time)\n        If 
is_using_profiling_tools (which is conflict with bench_kineto), return a series of 1 seconds\n        \"\"\"\n        if is_using_profiling_tools():\n            return [1 for _ in range(len(kernel_names_substr))]\n        \n        result = []\n        for substr in kernel_names_substr:\n            matched_names = self._get_matched_kernel_name(substr, allow_no_match=allow_missing, allow_multiple_match=allow_multiple_match)\n            if len(matched_names) == 0:\n                assert allow_missing\n                result.append(0)\n            else:\n                time_usage_sum = 0\n                run_cnt_sum = 0\n                for matched_name in matched_names:\n                    run_cnt = len(self.time_ranges[matched_name])\n                    if not allow_indivisible_run_count and run_cnt % self.num_tests != 0:\n                        raise RuntimeError(f\"Error: the number of runs for kernel {matched_name} ({run_cnt}) is indivisible by `num_tests` ({self.num_tests})\")\n                    time_usage_sum += sum([end-start for (start, end) in self.time_ranges[matched_name]])\n                    run_cnt_sum += run_cnt\n                denominator = run_cnt_sum if return_avg_individual_run else self.num_tests\n                result.append(time_usage_sum / denominator)\n        return result\n    \n    def get_kernel_time(self, kernel_name_substr: str) -> float:\n        return self.get_kernel_times([kernel_name_substr])[0]\n\n    def get_e2e_time(self, start_kernel_name_substr: str, end_kenrel_name_substr: str) -> float:\n        \"\"\"\n        Get the end-to-end time usage for a sequence of kernels\n        defined as \"last kernel end time\" - \"first kernel start time\"\n        If is_using_profiling_tools (which is conflict with bench_kineto), return 1 second\n        \"\"\"\n        if is_using_profiling_tools():\n            return 1\n        \n        start_kernel_name = self._get_matched_kernel_name(start_kernel_name_substr)[0]\n        
end_kernel_name = self._get_matched_kernel_name(end_kenrel_name_substr)[0]\n        num_start_kernels = len(self.time_ranges[start_kernel_name])\n        num_end_kernels = len(self.time_ranges[end_kernel_name])\n        if num_start_kernels%self.num_tests != 0:\n            raise RuntimeError(f\"Error: the number of runs for kernel {start_kernel_name} ({num_start_kernels}) is indivisible by `num_tests` ({self.num_tests})\")\n        if num_end_kernels%self.num_tests != 0:\n            raise RuntimeError(f\"Error: the number of runs for kernel {end_kernel_name} ({num_end_kernels}) is indivisible by `num_tests` ({self.num_tests})\")\n        time_spans = []\n        for i in range(self.num_tests):\n            end_time = self.time_ranges[end_kernel_name][(i+1)*(num_end_kernels//self.num_tests)-1][1]\n            start_time = self.time_ranges[start_kernel_name][i*(num_start_kernels//self.num_tests)][0]\n            time_spans.append((start_time, end_time))\n        result = sum([end-start for (start, end) in time_spans]) / self.num_tests\n        return result\n\n\ndef bench_kineto(fn: Callable, num_tests: int = 30,\n                 flush_l2: bool = True) -> BenchKinetoRawResult:\n    \"\"\"\n    Run `fn` for `num_tests` times under `bench_kineto` (CUPTI), and returns a BenchKinetoRawResult\n    \"\"\"\n    using_nsys = is_using_profiling_tools()\n\n    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle\n    flush_l2_size = int(8e9 // 4)\n\n    schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None\n    profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()\n    with profiler:\n        for i in range(2):\n            if i == 1 and not using_nsys:\n                _run_profiler_range_start_marker_kernel()    # This marks the start of the profiling range\n            for _ in 
range(num_tests):\n                if flush_l2:\n                    torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()\n                enable_nvtx_range = i == 1 and _ == num_tests-1\n                if enable_nvtx_range:\n                    torch.cuda.nvtx.range_push(\"profile_target\")\n                fn()\n                if enable_nvtx_range:\n                    torch.cuda.nvtx.range_pop()\n            if not using_nsys:\n                if i == 0:\n                    torch.cuda.synchronize()\n                profiler.step()\n    \n    if using_nsys:\n        return BenchKinetoRawResult(True, num_tests, {})\n\n    from torch.autograd.profiler_util import EventList, FunctionEvent   # pylint: disable=import-outside-toplevel\n    events: EventList = profiler.events() # type: ignore\n\n    # Filter out all events that are not function events\n    events: List[FunctionEvent] = [event for event in events if isinstance(event, FunctionEvent)]\n\n    # Filter out all events before the range marker\n    for idx, event in enumerate(events):\n        if event.name == \"profiler_range_start_marker_kernel\":\n            events = events[idx+1:]\n            break\n    else:\n        raise RuntimeError(\"Could not find profiler range start marker kernel event\")\n\n    # Get time ranges of each kernel\n    kernel_times = {}\n    for event in events:\n        kernel_name = event.name\n        if kernel_name not in kernel_times:\n            kernel_times[kernel_name] = []\n        kernel_times[kernel_name].append((event.time_range.start/1e6, event.time_range.end/1e6))\n    \n    return BenchKinetoRawResult(False, num_tests, kernel_times)\n\n@overload\ndef bench_by_cuda_events(kernels: List[Callable], num_warmups_each: int, num_runs_each: int) -> List[float]: ...\n\n@overload\ndef bench_by_cuda_events(kernels: Callable, num_warmups_each: int, num_runs_each: int) -> float: ...\n\ndef bench_by_cuda_events(kernels: Union[List[Callable], Callable], 
num_warmups_each: int, num_runs_each: int) -> Union[List[float], float]:\n    buf_for_l2_clear = torch.empty(int(256e6//4), dtype=torch.int32, device='cuda')\n\n    is_kernel_single_callable = isinstance(kernels, Callable)\n    if is_kernel_single_callable:\n        kernels = [kernels]\n\n    torch.cuda.synchronize()\n    for i in range(num_warmups_each):\n        for kernel in kernels:\n            kernel()\n            if i == 0:\n                # Ensure the first run is successful\n                try:\n                    torch.cuda.synchronize()\n                except Exception as e:\n                    print(f\"Kernel {kernel.__name__} failed on warmup run {i}: {e}\")\n                    return []\n\n    start_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels]\n    end_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels]\n    for i in range(num_runs_each):\n        for j, kernel in enumerate(kernels):\n            buf_for_l2_clear.random_()\n            if i == num_runs_each-1:\n                torch.cuda.nvtx.range_push(\"profile_target\")\n            start_events[j][i].record()\n            kernel()\n            end_events[j][i].record()\n            if i == num_runs_each-1:\n                torch.cuda.nvtx.range_pop()\n    \n    torch.cuda.synchronize()\n    time_usages = [\n        sum([start_events[j][i].elapsed_time(end_events[j][i])*1e-3 for i in range(num_runs_each)]) / num_runs_each\n        for j in range(len(kernels))\n    ]\n    if is_kernel_single_callable:\n        time_usages = time_usages[0]\n    return time_usages\n"
  },
  {
    "path": "tests/kernelkit/compare.py",
    "content": "from typing import List\n\nimport torch\n\ndef check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tensor, result: torch.Tensor):\n    \"\"\"\n    Return if two tensors are bitwise equal\n    Return a bool if avoid_sync is False, else return a tensor\n    \"\"\"\n    assert ans.shape == ref.shape, \"Shape mismatch\"\n    torch.all(torch.eq(ans, ref), out=result)\n\ndef check_is_bitwise_equal(name: str, ans: torch.Tensor, ref: torch.Tensor, quiet: bool = False) -> bool:\n    is_bitwise_equal = torch.equal(ans, ref)\n    if not quiet and not is_bitwise_equal:\n        print(f\"`{name}` mismatch: not bitwise equal. Mismatch count: {(ans != ref).sum().item()} out of {ans.numel()}\")\n    return is_bitwise_equal\n\ndef get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float:\n    \"\"\"\n    Calculate the cosine diff between two tensors\n    Return a float if avoid_sync is False, else return a tensor\n    \"\"\"\n    ans, ref = ans.double(), ref.double()\n    if (ref*ref).sum().item() < 1e-12:\n        return 0\n    denominator = (ans*ans + ref*ref).sum().item()\n    sim = 2 * (ans*ref).sum().item() / denominator\n    return 1 - sim\n\ndef check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7, quiet: bool = False) -> bool:\n    \"\"\"\n    Check if two tensors are close enough\n    Return a bool if avoid_sync is False, else return a tensor\n    \"\"\"\n    assert ans.shape == ref.shape, f\"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}\"\n    assert ans.dtype == ref.dtype, f\"`{name}` Dtype mismatch: {ans.dtype} vs {ref.dtype}\"\n    \n    ans = ans.clone().to(torch.float)\n    ref = ref.clone().to(torch.float)\n\n    def report_err(*args, **kwargs):\n        if not quiet:\n            print(*args, **kwargs)\n\n    # Deal with anomalies\n    def deal_with_anomalies(val: float):\n        ref_mask = (ref == val) if (val == val) else (ref != ref)\n     
   ans_mask = (ans == val) if (val == val) else (ans != ans)\n        ref[ref_mask] = 0.0\n        ans[ans_mask] = 0.0\n        if not torch.equal(ref_mask, ans_mask):\n            report_err(f\"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref\")\n            return False\n        return True\n    \n    anomalies_check_passed = True\n    anomalies_check_passed &= deal_with_anomalies(float(\"inf\"))\n    anomalies_check_passed &= deal_with_anomalies(float(\"-inf\"))\n    anomalies_check_passed &= deal_with_anomalies(float(\"nan\"))\n\n    cos_diff = get_cos_diff(ans, ref)\n    raw_abs_err = torch.abs(ans-ref)\n    raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6))\n    rel_err = raw_rel_err.masked_fill(raw_abs_err<abs_tol, 0)\n    abs_err = raw_abs_err.masked_fill(raw_rel_err<rel_tol, 0)\n    pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)\n\n    if not anomalies_check_passed:\n        return False\n\n    if not pass_mask.all():\n        report_err(f\"`{name}` mismatch\")\n        max_abs_err_pos: int = torch.argmax(abs_err, keepdim=True).item()\n        max_rel_err_pos: int = torch.argmax(rel_err, keepdim=True).item()\n        def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]:\n            result = []\n            for size in t.shape[::-1]:\n                result.append(pos % size)\n                pos = pos // size\n            assert pos == 0\n            return result[::-1]\n        report_err(f\"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}\")\n        report_err(f\"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}\")\n        report_err(f\"{pass_mask.sum()} out of {pass_mask.numel()} passed 
({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)\")\n        report_err(f\"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})\")\n        return False\n    else:\n        if abs(cos_diff) > cos_diff_tol:\n            report_err(f\"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})\")\n            return False\n        return True\n    \ndef check_is_allclose_comparator(name: str, ans: torch.Tensor, ref: torch.Tensor, out: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7):\n    out.fill_(check_is_allclose(name, ans, ref, abs_tol, rel_tol, cos_diff_tol))\n"
  },
  {
    "path": "tests/kernelkit/generate.py",
    "content": "import torch\n\ndef _get_new_non_contiguous_tensor_shape(shape):\n    \"\"\"\n    Get the expanded shape for a non-contiguous tensor.\n    The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1.\n    \"\"\"\n    return [dim+128 if dim_idx == len(shape)-1 else dim+1 for dim_idx, dim in enumerate(shape)]\n\ndef gen_non_contiguous_randn_tensor(shape, *args, **kwargs):\n    \"\"\"\n    Generate a randn tensor of `shape` that is a view into a larger base tensor\n    \"\"\"\n    new_shape = _get_new_non_contiguous_tensor_shape(shape)\n    base_tensor = torch.randn(new_shape, *args, **kwargs)\n    # Index with a tuple of slices (one per dimension) to take the leading sub-block\n    slices = tuple(slice(0, dim) for dim in shape)\n    return base_tensor[slices]\n\ndef gen_non_contiguous_tensor(shape, *args, **kwargs):\n    \"\"\"\n    Generate an uninitialized tensor of `shape` that is a view into a larger base tensor\n    \"\"\"\n    new_shape = _get_new_non_contiguous_tensor_shape(shape)\n    base_tensor = torch.empty(new_shape, *args, **kwargs)\n    slices = tuple(slice(0, dim) for dim in shape)\n    return base_tensor[slices]\n\ndef non_contiguousify(tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Copy `tensor` into a new tensor with the same shape, dtype and device, but backed by a larger base tensor\n    \"\"\"\n    new_tensor = gen_non_contiguous_tensor(tensor.shape, dtype=tensor.dtype, device=tensor.device)\n    new_tensor[:] = tensor\n    return new_tensor\n"
  },
  {
    "path": "tests/kernelkit/precision.py",
    "content": "import torch\n\n_is_low_precision_mode_stack = []\n\nclass LowPrecisionMode:\n    def __init__(self, enabled: bool = True):\n        self.enabled = enabled\n\n    def __enter__(self):\n        global _is_low_precision_mode_stack\n        _is_low_precision_mode_stack.append(self.enabled)\n    \n    def __exit__(self, exc_type, exc_value, traceback):\n        global _is_low_precision_mode_stack\n        _is_low_precision_mode_stack.pop()\n\ndef is_low_precision_mode() -> bool:\n    global _is_low_precision_mode_stack\n    if len(_is_low_precision_mode_stack) == 0:\n        return False\n    return _is_low_precision_mode_stack[-1]\n\ndef optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor:\n    assert tensor.dtype == torch.float32, \"Input tensor must be of dtype torch.float32 for optional casting.\"\n    if is_low_precision_mode():\n        tensor_bf16 = tensor.to(torch.bfloat16)\n        tensor_fp32 = tensor_bf16.to(torch.float32)\n        return tensor_fp32\n    else:\n        return tensor\n"
  },
  {
    "path": "tests/kernelkit/utils.py",
    "content": "import os\nimport functools\n\ncolors = {\n    'RED_FG': '\\033[31m',\n    'GREEN_FG': '\\033[32m',\n    'CYAN_FG': '\\033[36m',\n    'GRAY_FG': '\\033[90m',\n    'YELLOW_FG': '\\033[33m',\n    'RED_BG': '\\033[41m',\n    'GREEN_BG': '\\033[42m',\n    'CYAN_BG': '\\033[46m',\n    'YELLOW_BG': '\\033[43m',\n    'GRAY_BG': '\\033[100m',\n    'CLEAR': '\\033[0m'\n}\n\ndef cdiv(a: int, b: int) -> int:\n    return (a + b - 1) // b\n\n@functools.lru_cache()\ndef is_using_profiling_tools() -> bool:\n    \"\"\"\n    Return whether we are running under profiling tools like nsys or ncu\n\n    NOTE cuda-gdb will also cause conflict with CUPTI (bench_kineto) but currently we lack ways to detect it\n    \"\"\"\n    is_using_nsys = os.environ.get('NSYS_PROFILING_SESSION_ID') is not None\n    is_using_ncu = os.environ.get('NV_COMPUTE_PROFILER_PERFWORKS_DIR') is not None\n    is_using_compute_sanitizer = os.environ.get('NV_SANITIZER_INJECTION_PORT_RANGE_BEGIN') is not None\n    return is_using_nsys or is_using_ncu or is_using_compute_sanitizer\n\ndef set_random_seed(seed: int):\n    import random\n    import numpy as np\n    import torch\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n\nclass Counter:\n    def __init__(self):\n        self.count = 0\n\n    def next(self) -> int:\n        self.count += 1\n        return self.count - 1\n"
  },
  {
    "path": "tests/lib.py",
    "content": "import dataclasses\nimport os\nimport enum\nfrom typing import List, Optional\nimport random\n\nimport torch\nimport kernelkit as kk\nimport flash_mla\n\nimport quant\n\nclass TestTarget(enum.Enum):\n    FWD = 0\n    DECODE = 1\n\n@dataclasses.dataclass\nclass ExtraTestParamForDecode:\n    b: int\n    is_varlen: bool\n    have_zero_seqlen_k: bool\n    extra_s_k: Optional[int] = None\n    extra_topk: Optional[int] = None\n    block_size: int = 64\n    extra_block_size: Optional[int] = None\n    have_extra_topk_length: bool = False\n    \n@dataclasses.dataclass\nclass TestParam:\n    s_q: int\n    s_kv: int\n    topk: int\n    h_q: int = 128\n    h_kv: int = 1\n    d_qk: int = 512\n    d_v: int = 512\n    seed: int = -1   # -1: to be filled automatically\n    check_correctness: bool = True\n    is_all_indices_invalid: bool = False    # All indices are invalid, i.e., all indices are set to a large number (e.g., 2147483647)\n    num_runs: int = 10\n    have_attn_sink: bool = False\n    have_topk_length: bool = False\n    decode: Optional[ExtraTestParamForDecode] = None\n\n@dataclasses.dataclass\nclass RawTestParamForDecode:\n    \"\"\"\n    \"Flattened\" test parameters for decoding test\n    \n    In our test script, to maintain compatibility with TestParam, we embed decode-only parameters into TestParam.decode, which is not very convenient when constructing testcases. 
So here we have a \"flattened\" version of test parameters for decoding test.\n    \"\"\"\n    b: int\n    h_q: int\n    s_q: int\n    h_kv: int\n    s_kv: int\n    is_varlen: bool\n    topk: int\n    is_all_indices_invalid: bool = False\n    have_zero_seqlen_k: bool = False\n    have_topk_length: bool = False\n    enable_attn_sink: bool = True\n    extra_s_k: Optional[int] = None\n    extra_topk: Optional[int] = None\n    block_size: int = 64\n    extra_block_size: Optional[int] = None\n    have_extra_topk_length: bool = False\n    d_qk: int = 576      # Q/K head dim (= dv + RoPE dim)\n    d_v: int = 512     # V head dim\n    check_correctness: bool = True\n    num_runs: int = 10\n    seed: int = -1\n\n    def to_test_param(self) -> TestParam:\n        return TestParam(\n            self.s_q, self.s_kv, self.topk, self.h_q, self.h_kv, self.d_qk, self.d_v,\n            self.seed, self.check_correctness,\n            self.is_all_indices_invalid,\n            self.num_runs,\n            self.enable_attn_sink,\n            self.have_topk_length,\n            decode = ExtraTestParamForDecode(\n                self.b, self.is_varlen, self.have_zero_seqlen_k,\n                self.extra_s_k, self.extra_topk,\n                self.block_size, self.extra_block_size, self.have_extra_topk_length\n            )\n        )\n    \n@dataclasses.dataclass\nclass Testcase:\n    p: TestParam\n    dOut: torch.Tensor  # [s_q, h_q, d_v]\n    q: torch.Tensor     # [s_q, h_q, d_qk]\n    kv: torch.Tensor    # [s_kv, h_kv, d_qk]\n    indices: torch.Tensor   # [s_q, h_kv, topk]\n    sm_scale: float\n    attn_sink: Optional[torch.Tensor]   # [h_q]\n    topk_length: Optional[torch.Tensor]  # [s_q]\n\ndef _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int]) -> torch.Tensor:\n    \"\"\"\n    Generate random permutations in batch\n    The return tensor, denoted as `res`, has a shape of [batch_size, perm_size]. 
`0 <= res[i, :] < perm_range[i]` holds.\n    Values within each row are unique.\n    If, for some `i`, `perm_range[i] < perm_size` holds, then `res[i, :]` contains values in `[0, perm_range[i])` as many as possible, and the rest are filled with `padding`.\n    \"\"\"\n    assert not torch.are_deterministic_algorithms_enabled()\n    torch.use_deterministic_algorithms(True)\n    perm_range_max = max(int(torch.max(perm_range).item()), perm_size)\n    rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32)\n    rand[torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) >= perm_range.view(batch_size, 1)] = float(\"-inf\")    # Fill invalid positions, so that the following `topk` operators will select positions within `perm_range` first\n    res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32)\n    if len(paddings) == 1:\n        res[res >= perm_range.view(batch_size, 1)] = paddings[0]\n    else:\n        fillers = torch.tensor(paddings, dtype=torch.int32).index_select(0, torch.randint(0, len(paddings), (res.numel(), ), dtype=torch.int32))\n        res.masked_scatter_(res >= perm_range.view(batch_size, 1), fillers)\n    torch.use_deterministic_algorithms(False)\n    return res\n\ndef generate_testcase(t: TestParam) -> Testcase:\n    kk.set_random_seed(t.seed)\n    q = torch.randn((t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10\n    kv = torch.randn((t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10\n    do = torch.randn((t.s_q, t.h_q, t.d_v), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10\n\n    q.clamp_(-10, 10)\n    kv.clamp_(-10, 10)\n    do.clamp_(-10, 10)\n    \n    invalid_indices_candidate = [-2147483648, -123456, -1, t.s_kv, 114514, 1919810, 2147480000, 2147483647]\n    indices = _randperm_batch(t.s_q, torch.full((t.s_q, ), t.s_kv, dtype=torch.int32), t.topk, invalid_indices_candidate).view(t.s_q, t.h_kv, t.topk)\n\n    if t.is_all_indices_invalid:\n       
 all_indices_invalid_mask = torch.randn(t.s_q, device='cpu') < -2\n        indices[all_indices_invalid_mask[:, None, None].broadcast_to(indices.shape)] = random.choice(invalid_indices_candidate)\n    indices = indices.to(q.device)\n\n    attn_sink = None\n    if t.have_attn_sink:\n        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)\n        mask = torch.randn((t.h_q, ), dtype=torch.float32)\n        attn_sink[mask < -0.5] = float(\"-inf\")\n        attn_sink[mask > +0.5] = float(\"+inf\")\n\n    topk_length = None\n    if t.have_topk_length:\n        topk_length = torch.randint(0, max(t.topk + 1, 64), (t.s_q, ), dtype=torch.int32, device=q.device).clamp_max(t.topk)\n\n    q = kk.non_contiguousify(q)\n    kv = kk.non_contiguousify(kv)\n    do = kk.non_contiguousify(do)\n    indices = kk.non_contiguousify(indices)\n\n    return Testcase(\n        p=t,\n        dOut=do,\n        q=q,\n        kv=kv,\n        indices=indices,\n        sm_scale=0.5,   # Otherwise dK is too small compared to dV\n        attn_sink=attn_sink,\n        topk_length=topk_length\n    )\n\n\n@dataclasses.dataclass\nclass KVScope:\n    t: TestParam\n    cache_seqlens: torch.Tensor\n    block_table: torch.Tensor\n    blocked_k: torch.Tensor\n    abs_indices: torch.Tensor\n    indices_in_kvcache: torch.Tensor\n    topk_length: Optional[torch.Tensor]\n    blocked_k_quantized: Optional[torch.Tensor] = None\n\n    def quant_and_dequant_(self):\n        \"\"\"\n        For FP8 cases, we need to quantize the KV cache for Flash MLA.\n        Besides, the quantization error may be too large to be distinguished from wrong kernels, so we de-quantize kvcache here to mitigate quantization error\n        \"\"\"\n        fp8_kvcache_layout = None\n        if self.t.d_qk == 576:\n            fp8_kvcache_layout = quant.FP8KVCacheLayout.V32_FP8Sparse\n        elif self.t.d_qk == 512:\n            assert self.abs_indices is not None\n            fp8_kvcache_layout = 
quant.FP8KVCacheLayout.MODEL1_FP8Sparse\n        else:\n            assert False\n        self.blocked_k_quantized = quant.quantize_k_cache(self.blocked_k, fp8_kvcache_layout)\n        blocked_k_dequantized = quant.dequantize_k_cache(self.blocked_k_quantized, fp8_kvcache_layout)\n        self.blocked_k = blocked_k_dequantized\n\n    def get_kvcache_for_flash_mla(self) -> torch.Tensor:\n        \"\"\"\n        Return the quantized blocked_k for Flash MLA\n        \"\"\"\n        assert self.blocked_k_quantized is not None, \"Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`\"\n        return self.blocked_k_quantized\n    \n    def apply_perm(self, perm: torch.Tensor) -> \"KVScope\":\n        \"\"\"\n        Apply a batch permutation to this KVScope. Used for batch-invariance test\n        \"\"\"\n        new_kvscope = KVScope(\n            self.t,\n            self.cache_seqlens[perm],\n            self.block_table[perm],\n            self.blocked_k,\n            self.abs_indices[perm],\n            self.indices_in_kvcache[perm],\n            self.topk_length[perm] if self.topk_length is not None else None,\n            self.blocked_k_quantized\n        )\n        return new_kvscope\n    \n@dataclasses.dataclass\nclass TestcaseForDecode:\n    p: TestParam\n    q: torch.Tensor     # [b, s_q, h_q, d_qk]\n    attn_sink: Optional[torch.Tensor]   # [h_q]\n    sm_scale: float\n    kv_scope: KVScope\n    extra_kv_scope: Optional[KVScope]\n\ndef generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:\n    kk.set_random_seed(t.seed)\n    assert t.h_q % t.h_kv == 0\n    assert t.decode is not None\n\n    q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk))\n    q.clamp_(min=-1.0, max=1.0)\n\n    attn_sink = None\n    if t.have_attn_sink:\n        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)\n        inf_mask = torch.randn((t.h_q, ), dtype=torch.float32)\n        attn_sink[inf_mask > 0.5] = float(\"inf\")\n        
attn_sink[inf_mask < -0.5] = float(\"-inf\")\n\n    def generate_one_k_scope(s_k: int, block_size: int, topk: int, is_varlen: bool, have_zero_seqlen: bool, is_all_indices_invalid: bool, have_topk_length: bool) -> KVScope:\n        b = t.decode.b  # type: ignore\n        cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device='cpu')\n        if is_varlen:\n            for i in range(b):\n                cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q)\n\n        if have_zero_seqlen:\n            zeros_mask = torch.randn(b, dtype=torch.float32, device='cpu') > 0\n            cache_seqlens_cpu[zeros_mask] = 0\n\n        max_seqlen_alignment = 4 * block_size\n        max_seqlen_pad = max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) * max_seqlen_alignment\n        cache_seqlens = cache_seqlens_cpu.cuda()\n\n        assert max_seqlen_pad % block_size == 0\n        block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)\n        block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)\n\n        blocked_k = kk.gen_non_contiguous_randn_tensor((block_table.numel(), block_size, t.h_kv, t.d_qk)) / 10\n        blocked_k.clamp_(min=-1.0, max=1.0)\n    \n        abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32)\n        if is_all_indices_invalid:\n            abs_indices.fill_(-1)\n        else:\n            abs_indices[:] = _randperm_batch(b*t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1]).view(b, t.s_q, topk)\n        indices_in_kvcache = quant.abs_indices2indices_in_kvcache(abs_indices, block_table, block_size)\n\n        topk_length = torch.randint(0, topk+1, (b, ), dtype=torch.int32, device=q.device) if have_topk_length else None\n\n        # Mask nonused KV as NaN\n        if have_topk_length:\n            indices_in_kvcache_masked = indices_in_kvcache.clone()\n            
indices_in_kvcache_masked[torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) >= (topk_length.view(b, 1, 1) if have_topk_length else topk)] = -1\n        else:\n            indices_in_kvcache_masked = indices_in_kvcache\n        \n        blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk)\n        nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu')\n        nonused_indices_mask[indices_in_kvcache_masked] = False\n        blocked_k[nonused_indices_mask, :, :] = float(\"nan\")\n        blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk)\n    \n        block_table = kk.non_contiguousify(block_table)\n        abs_indices = kk.non_contiguousify(abs_indices)\n        indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache)\n        return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length)\n\n    kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length)\n    kv_scope0.quant_and_dequant_()\n    if t.decode.extra_topk is not None:\n        if t.decode.extra_s_k is None:\n            t.decode.extra_s_k = t.decode.extra_topk*2\n        if t.decode.extra_block_size is None:\n            t.decode.extra_block_size = t.decode.block_size\n        kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length)\n        kv_scope1.quant_and_dequant_()\n    else:\n        assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length\n        kv_scope1 = None\n    \n    sm_scale = t.d_qk ** -0.55\n\n    q = kk.non_contiguousify(q)\n    return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1)\n\n\ndef run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, 
return_p_sum: bool):\n    assert not return_p_sum\n    return flash_mla.flash_mla_sparse_fwd(\n        t.q, t.kv, t.indices,\n        sm_scale=t.sm_scale,\n        attn_sink=t.attn_sink,\n        topk_length=t.topk_length\n    )\n\ndef run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits):\n    assert p.decode is not None\n    return flash_mla.flash_mla_with_kvcache(\n        t.q,\n        t.kv_scope.get_kvcache_for_flash_mla(),\n        None, None, p.d_v,\n        tile_scheduler_metadata, num_splits,\n\n        t.sm_scale, False, True,\n        t.kv_scope.indices_in_kvcache,\n        t.attn_sink,\n        t.extra_kv_scope.get_kvcache_for_flash_mla() if t.extra_kv_scope is not None else None,\n        t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None,\n        t.kv_scope.topk_length,\n        t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None\n    )\n\n\n@dataclasses.dataclass\nclass FlopsAndMemVolStatistics:\n    \"\"\"\n    FLOPs and memory volume statistics for prefilling\n    \"\"\"\n    fwd_flop: float\n    fwd_mem_vol: float\n\ndef count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics:\n    total_topk = (p.s_q*p.topk) if t.topk_length is None else t.topk_length.sum().item()\n    indices_valid_mask = (t.indices >= 0) & (t.indices < p.s_kv)\n    if t.topk_length is not None:\n        indices_valid_mask &= (torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk)) < t.topk_length[:, None, None]\n    num_valid_indices = indices_valid_mask.sum().item()\n\n    fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v)\n    fwd_mem_vol = num_valid_indices*p.d_qk*2 + p.s_q*p.h_q*(p.d_qk+p.d_v)*2\n    return FlopsAndMemVolStatistics(\n        fwd_flop,\n        fwd_mem_vol,\n    )\n\n@dataclasses.dataclass\nclass FlopsAndMemVolStatisticsForDecode:\n    \"\"\"\n    FLOPs and memory volume statistics for 
decoding\n    \"\"\"\n    flop: float\n    mem_vol: float\n\ndef count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode) -> FlopsAndMemVolStatisticsForDecode:\n    assert p.decode\n    b = p.decode.b\n\n    def get_num_attended_tokens(kv_scope: KVScope) -> int:\n        topk = kv_scope.indices_in_kvcache.shape[-1]\n        if kv_scope.topk_length is None:\n            return b * p.s_q * topk\n        else:\n            return int(kv_scope.topk_length.sum().item()) * p.s_q\n        \n    def get_num_retrieved_tokens(kv_scope: KVScope) -> int:\n        if kv_scope.topk_length is None:\n            indices = kv_scope.indices_in_kvcache\n        else:\n            indices = kv_scope.indices_in_kvcache.clone()\n            batch, s_q, topk = indices.shape\n            mask = torch.arange(0, topk, device=indices.device).view(1, 1, topk).broadcast_to(batch, s_q, topk) >= kv_scope.topk_length.view(batch, 1, 1)\n            indices[mask] = -1\n        num_unique_tokens = indices.unique().numel()    # type: ignore\n        return num_unique_tokens\n\n    num_attended_tokens = get_num_attended_tokens(t.kv_scope) + (get_num_attended_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)\n    num_retrieved_tokens = get_num_retrieved_tokens(t.kv_scope) + (get_num_retrieved_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)\n\n    compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v)\n    kv_token_size = 656 if p.d_qk == 576 else 576   # Assume FP8 KV Cache\n    mem_vol = sum([\n        2 * b * p.s_q * p.h_q * p.d_qk, # Q\n        num_retrieved_tokens * kv_token_size,   # K\n        2 * b * p.s_q * p.h_q * p.d_v, # O\n    ])\n    return FlopsAndMemVolStatisticsForDecode(\n        compute_flop,\n        mem_vol\n    )\n\ndef is_no_cooldown() -> bool:\n    return os.environ.get('NO_COOLDOWN', '').lower() in ['1', 'yes', 'y']\n"
  },
  {
    "path": "tests/quant.py",
    "content": "import enum\nfrom typing import Tuple\n\nimport torch\n\nclass FP8KVCacheLayout(enum.Enum):\n    V32_FP8Sparse = 1\n    MODEL1_FP8Sparse = 2\n\n    def get_meta(self) -> Tuple[int, int, int, int, int]:\n        # Return: (d, d_nope, d_rope, tile_size, num_tiles)\n        return {\n            FP8KVCacheLayout.V32_FP8Sparse: (576, 512, 64, 128, 4),\n            FP8KVCacheLayout.MODEL1_FP8Sparse: (512, 448, 64, 64, 7)\n        }[self]\n\ndef _cast_scale_inv_to_ue8m0(scales_inv: torch.Tensor, out_dtype = torch.float32) -> torch.Tensor:\n    return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype)\n\ndef quantize_k_cache(\n    input_k_cache: torch.Tensor,    # (num_blocks, block_size, h_k, d)\n    kvcache_layout: FP8KVCacheLayout,\n) -> torch.Tensor:\n    \"\"\"\n    Quantize the k-cache\n    For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py\n    \"\"\"\n    d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta()\n    assert input_k_cache.shape[-1] == d\n    num_blocks, block_size, h_k, _ = input_k_cache.shape\n    assert h_k == 1\n    input_k_cache = input_k_cache.squeeze(2)    # [num_blocks, block_size, d]\n    input_elem_size = input_k_cache.element_size()\n\n    if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse:\n        bytes_per_token = d_nope + num_tiles*4 + input_elem_size*d_rope\n        result = torch.empty((num_blocks, block_size+1, bytes_per_token), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size, :]\n        result_k_nope_part = result[..., :d_nope]\n        result_k_scale_factor = result[..., d_nope: d_nope + num_tiles*4].view(torch.float32)\n        result_k_rope_part = result[..., d_nope + num_tiles*4:].view(input_k_cache.dtype)\n        result_k_rope_part[:] = input_k_cache[..., d_nope:]\n\n        for tile_idx in range(0, num_tiles):\n            cur_scale_factors_inv = torch.abs(input_k_cache[..., 
tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size]\n            cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv)\n            result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv\n\n            cur_scale_factors_inv.unsqueeze_(-1)    # [num_blocks, block_size, 1]\n            cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)\n            result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope\n        \n        result = result.view(num_blocks, block_size, 1, -1)\n        return result\n    \n    elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse:\n        bytes_per_token = d_nope + 2*d_rope + num_tiles + 1\n        size_per_block_padded = (block_size*bytes_per_token + 576-1) // 576 * 576\n        result = torch.empty((num_blocks, size_per_block_padded), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size*bytes_per_token]\n        result_k_nope_rope_part = result[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope)\n        result_k_nope = result_k_nope_rope_part[:, :, :d_nope]  # [num_blocks, block_size, d_nope]\n        result_k_rope = result_k_nope_rope_part[:, :, d_nope:].view(input_k_cache.dtype)  # [num_blocks, block_size, d_rope]\n        result_k_scale_factor = result[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu)   # [num_blocks, block_size, num_tiles]\n\n        result_k_rope[:] = input_k_cache[..., d_nope:]\n        for tile_idx in range(0, num_tiles):\n            cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size]\n            cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv)\n            
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv.to(torch.float8_e8m0fnu)\n\n            cur_scale_factors_inv = cur_scale_factors_inv.view(num_blocks, block_size, 1)\n            cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)\n            result_k_nope[:, :, tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope\n        \n        result = result.view(num_blocks, block_size, 1, -1)\n        return result\n\n    else:\n        raise NotImplementedError(f\"Unsupported kvcache_layout: {kvcache_layout}\")\n    \n\ndef dequantize_k_cache(\n    quant_k_cache: torch.Tensor,    # (num_blocks, block_size, 1, bytes_per_token)\n    kvcache_layout: FP8KVCacheLayout,\n) -> torch.Tensor:\n    \"\"\"\n    De-quantize the k-cache\n    \"\"\"\n    d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta()\n    num_blocks, block_size, h_k, _ = quant_k_cache.shape\n    assert h_k == 1\n    result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device)\n\n    if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse:\n        quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)\n\n        input_nope = quant_k_cache[..., :d_nope]\n        input_scale = quant_k_cache[..., d_nope:d_nope + num_tiles*4].view(torch.float32)\n        input_rope = quant_k_cache[..., d_nope + num_tiles*4:].view(torch.bfloat16)\n        result[..., d_nope:] = input_rope\n\n        for tile_idx in range(0, num_tiles):\n            cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32)\n            cur_scales = input_scale[..., tile_idx].unsqueeze(-1)\n            result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales\n\n    elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse:\n        quant_k_cache = quant_k_cache.view(num_blocks, -1)  # [num_blocks, ...]  
\n        input_nope_rope = quant_k_cache[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope)\n        input_nope = input_nope_rope[:, :, :d_nope]\n        input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16)\n        input_scale = quant_k_cache[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu)   # [num_blocks, block_size, num_tiles]\n\n        result[..., d_nope:] = input_rope\n        for tile_idx in range(0, num_tiles):\n            cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.bfloat16)\n            cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1)\n            result[..., tile_idx*tile_size: (tile_idx+1)*tile_size] = cur_nope * cur_scales\n            \n    else:\n        raise NotImplementedError(f\"Unsupported kvcache_layout: {kvcache_layout}\")\n    \n    result = result.view(num_blocks, block_size, 1, d)\n    return result\n\n\ndef abs_indices2indices_in_kvcache(\n    abs_indices: torch.Tensor,  # [b, s_q, topk]\n    block_table: torch.Tensor,  # [b, /]\n    block_size: int,\n) -> torch.Tensor:\n    \"\"\"\n    Convert abs_indices (logical index, ranging from 0 to s_k-1) to index expected by the sparse attn kernel\n    Equivalent to:\n    \n    b, s_q, topk = abs_indices.shape\n    indices_in_kvcache = torch.empty_like(abs_indices)\n    for i in range(b):\n        cur_abs_indices = abs_indices[i, :, :].clone()  # [s_q, topk]\n        invalid_mask = cur_abs_indices == -1\n        cur_abs_indices[invalid_mask] = 0\n        cur_indices_in_kvcache = block_table[i].index_select(0, cur_abs_indices.flatten()//block_size).view(s_q, topk)*block_size + cur_abs_indices%block_size\n        cur_indices_in_kvcache[invalid_mask] = -1\n        indices_in_kvcache[i] = cur_indices_in_kvcache\n    return indices_in_kvcache\n\n    \"\"\"\n    b, s_q, topk = abs_indices.shape\n    _, max_blocks_per_seq = block_table.shape\n\n   
 abs_indices = abs_indices.clone()\n    invalid_mask = abs_indices == -1\n    abs_indices[invalid_mask] = 0\n\n    real_block_idxs = block_table.view(-1).index_select(0, (abs_indices//block_size + torch.arange(0, b).view(b, 1, 1)*max_blocks_per_seq).view(-1))\n    indices_in_kvcache = real_block_idxs.view(b, s_q, topk)*block_size + abs_indices%block_size\n    indices_in_kvcache[invalid_mask] = -1\n\n    return indices_in_kvcache"
  },
  {
    "path": "tests/ref.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\n\nfrom lib import TestParam, Testcase, TestcaseForDecode, KVScope\n\ndef _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int) -> torch.Tensor:\n    if lse1 is None:\n        return lse0\n    else:\n        return torch.logsumexp(\n            torch.stack([\n                lse0.view(s_q, h_q),\n                lse1.broadcast_to(s_q, h_q)\n            ], dim=0),\n            dim=0\n        )\n        \ndef ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Returns:\n    - o: [s_q, h_q, dv]\n    - o_fp32: [s_q, h_q, dv]\n    - max_logits: [s_q, h_q]\n    - lse: [s_q, h_q]\n    \"\"\"\n    indices = t.indices.clone().squeeze(1)\n    if t.topk_length is not None:\n        mask = torch.arange(p.topk, device=t.topk_length.device).unsqueeze(0).broadcast_to(p.s_q, p.topk) >= t.topk_length.unsqueeze(1)   # [s_q, topk]\n        indices[mask] = -1\n    invalid_mask = (indices < 0) | (indices >= p.s_kv)    # [s_q, topk]\n    indices[invalid_mask] = 0\n\n    q = t.q.float()\n    gathered_kv = t.kv.index_select(dim=0, index=indices.flatten()).reshape(p.s_q, p.topk, p.d_qk).float()   # [s_q, topk, d_qk]\n    P = (q @ gathered_kv.transpose(1, 2))   # [s_q, h_q, topk]\n    P *= t.sm_scale\n    P[invalid_mask.unsqueeze(1).broadcast_to(P.shape)] = float(\"-inf\")\n\n    orig_lse = torch.logsumexp(P, dim=-1)   # [s_q, h_q]\n    max_logits = P.max(dim=-1).values   # [s_q, h_q]\n\n    lse_for_o = _merge_two_lse(orig_lse, t.attn_sink, p.s_q, p.h_q)\n    if not torch.is_inference_mode_enabled():\n        lse_for_o = lse_for_o.clone()\n    lse_for_o[lse_for_o == float(\"-inf\")] = float(\"+inf\")   # So that corresponding O will be 0\n    s_for_o = torch.exp(P - lse_for_o.unsqueeze(-1))\n    out = s_for_o @ gathered_kv[..., :p.d_v]   # [s_q, h_q, dv]\n\n    lonely_q_mask = orig_lse == float(\"-inf\")   # [s_q, 
h_q]\n    orig_lse[lonely_q_mask] = float(\"+inf\")\n    return (out.to(torch.bfloat16), out, max_logits, orig_lse)\n\n\ndef ref_sparse_attn_decode(\n    p: TestParam,\n    t: TestcaseForDecode\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    A reference implementation of sparse decoding attention in PyTorch\n    \"\"\"\n    assert p.h_kv == 1\n    assert p.decode is not None\n    b = p.decode.b\n\n    def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]:\n        assert kv_scope.indices_in_kvcache is not None\n        topk = kv_scope.indices_in_kvcache.size(-1)\n        indices_in_kv_cache_fixed = torch.clamp_min(kv_scope.indices_in_kvcache, 0) # Otherwise torch.index_select will complain\n        gathered_kv = kv_scope.blocked_k.view(-1, p.d_qk).index_select(0, indices_in_kv_cache_fixed.view(-1)).view(b, p.s_q, topk, p.d_qk)  # [b, s_q, topk, d]\n        invalid_mask = kv_scope.indices_in_kvcache == -1\n        if kv_scope.topk_length is not None:\n            invalid_mask |= torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1)\n        return gathered_kv, invalid_mask\n    \n    gathered_kv, invalid_mask = process_kv_scope(t.kv_scope)\n    if t.extra_kv_scope is not None:\n        gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope)\n        gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2)  # [b, s_q, topk+extra_topk, d]\n        invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2)   # [b, s_q, topk+extra_topk]\n\n    gathered_kv = gathered_kv.view(b*p.s_q, -1, p.d_qk).float()\n    gathered_kv[gathered_kv != gathered_kv] = 0.0\n    q = t.q.float().view(b*p.s_q, p.h_q, p.d_qk)\n    attn_weight = q @ gathered_kv.transpose(-1, -2)  # [t.b*t.s_q, t.h_q, topk+extra_topk]\n    attn_weight *= t.sm_scale\n    attn_weight[invalid_mask.view(b*p.s_q, 1, -1).broadcast_to(b*p.s_q, p.h_q, invalid_mask.size(-1))] = float(\"-inf\")\n    lse = 
attn_weight.logsumexp(dim=-1)  # [t.b*t.s_q, t.h_q]\n    attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1))\n    output = attn_weight @ gathered_kv[..., :p.d_v]    # [t.b*t.s_q, t.h_q, t.dv]\n    output = output.view(b, p.s_q, p.h_q, p.d_v)\n    lse = lse.view(b, p.s_q, p.h_q)\n\n    # Attention sink\n    if t.attn_sink is not None:\n        output *= (1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 1, p.h_q) - lse))).unsqueeze(-1)\n\n    # Correct for q tokens that have no attendable k\n    lonely_q_mask = (lse == float(\"-inf\"))\n    output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0\n    lse[lonely_q_mask] = float(\"+inf\")\n\n    return output.to(torch.bfloat16), lse.transpose(1, 2)"
  },
  {
    "path": "tests/test_flash_mla_dense_decoding.py",
    "content": "import argparse\nimport math\nimport random\nimport dataclasses\nfrom typing import Tuple\n\nimport torch\n\nimport kernelkit as kk\nimport flash_mla\n\n@dataclasses.dataclass\nclass TestParam:\n    b: int    # Batch size\n    s_q: int  # Number of queries for one request\n    s_k: int  # Seq len, or mean seq len if varlen == True\n    is_varlen: bool\n    is_causal: bool\n    test_performance: bool = True\n    have_zero_seqlen_k: bool = False\n    block_size: int = 64\n    h_q: int = 128    # Number of q heads\n    h_kv: int = 1     # Number of kv heads\n    d: int = 576      # Q/K head dim (= dv + RoPE dim)\n    dv: int = 512     # V head dim\n    seed: int = 0\n\n\ndef generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Generate test data from a given configuration\n    Return: [cache_seqlens, q, block_table, blocked_k]\n    Pay attention: This function changes the random seed\n    \"\"\"\n    random.seed(t.seed)\n    torch.manual_seed(t.seed)\n    torch.cuda.manual_seed(t.seed)\n    torch.backends.cudnn.deterministic = True\n\n    assert t.h_q % t.h_kv == 0\n\n    cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu')\n    if t.is_varlen:\n        for i in range(t.b):\n            cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q)\n\n    if t.have_zero_seqlen_k:\n        zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0\n        cache_seqlens_cpu[zeros_mask] = 0\n\n    max_seqlen = int(cache_seqlens_cpu.max().item())\n    max_seqlen_pad = kk.cdiv(max_seqlen, 256) * 256\n    cache_seqlens = cache_seqlens_cpu.cuda()\n\n    q = torch.randn(t.b, t.s_q, t.h_q, t.d) / 10\n    q.clamp_(min=-1.0, max=1.0)\n\n    block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size)\n    block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, 
-1)\n    blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10\n    blocked_k.clamp_(min=-1.0, max=1.0)\n\n    for i in range(t.b):\n        cur_len = int(cache_seqlens_cpu[i].item())\n        cur_num_blocks = kk.cdiv(cur_len, t.block_size)\n        blocked_k[block_table[i][cur_num_blocks:]] = float(\"nan\")\n        if cur_len % t.block_size != 0:\n            blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float(\"nan\")\n        block_table[i][cur_num_blocks:] = 2147480000\n    return cache_seqlens, q, block_table, blocked_k\n\n\ndef reference_torch(\n    cache_seqlens: torch.Tensor,    # [batch_size]\n    block_table: torch.Tensor,      # [batch_size, ?]\n    q: torch.Tensor,    # [batch_size, s_q, h_q, d]\n    blocked_k: torch.Tensor,    # [?, block_size, h_kv, d]\n    dv: int,\n    is_causal: bool,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    A reference implementation in PyTorch\n    \"\"\"\n\n    def scaled_dot_product_attention(\n        batch_idx: int,\n        query: torch.Tensor,    # [h_q, s_q, d]\n        kv: torch.Tensor,      # [h_kv, s_k, d]\n        dv: int,\n        is_causal,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        h_q = query.size(0)\n        h_kv = kv.size(0)\n        s_q = query.shape[-2]\n        s_k = kv.shape[-2]\n        query = query.float()\n        kv = kv.float()\n        if h_kv != 1:\n            kv = kv.repeat_interleave(h_q // h_kv, dim=0)\n        kv[kv != kv] = 0.0\n        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]\n        if is_causal and query.size(1) > 1:\n            mask = torch.ones(s_q, s_k, dtype=torch.bool)\n            if is_causal:\n                mask = mask.tril(diagonal=s_k - s_q)\n            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)\n            attn_bias.masked_fill_(mask.logical_not(), float(\"-inf\"))\n            attn_weight += attn_bias.to(q.dtype)\n        attn_weight /= math.sqrt(query.size(-1))\n    
    lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]\n        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)\n        output = attn_weight @ kv[..., :dv]    # [h_q, s_q, dv]\n        # Correct for q tokens which has no attendable k\n        lonely_q_mask = (lse == float(\"-inf\"))\n        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0\n        lse[lonely_q_mask] = float(\"+inf\")\n\n        return output, lse\n\n    b, s_q, h_q, d = q.size()\n    block_size = blocked_k.size(1)\n    h_kv = blocked_k.size(2)\n    cache_seqlens_cpu = cache_seqlens.cpu()\n    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)\n    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)\n    for i in range(b):\n        cur_len = int(cache_seqlens_cpu[i].item())\n        cur_num_blocks = kk.cdiv(cur_len, block_size)\n        cur_block_indices = block_table[i][0: cur_num_blocks]\n        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]\n        cur_out, cur_lse = scaled_dot_product_attention(\n            i,\n            q[i].transpose(0, 1),\n            cur_kv.transpose(0, 1),\n            dv,\n            is_causal\n        )\n        out_ref[i] = cur_out.transpose(0, 1)\n        lse_ref[i] = cur_lse\n    out_ref = out_ref.to(q.dtype)\n    return out_ref, lse_ref\n\n\n@torch.inference_mode()\ndef test_flash_mla(t: TestParam):\n    print('-------------------------------')\n    print(f\"Running on {t}...\")\n\n    # Generating test data\n    torch.cuda.synchronize()\n    cache_seqlens, q, block_table, blocked_k, = generate_test_data(t)\n\n    tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()\n\n    def run_flash_mla():\n        return flash_mla.flash_mla_with_kvcache(\n            q,\n            blocked_k,\n            block_table,\n            cache_seqlens,\n            t.dv,\n            tile_scheduler_metadata,\n            num_splits,\n            causal=t.is_causal\n        )\n\n    
out_ans, lse_ans = run_flash_mla()\n    out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal)\n    is_correct = True\n    is_correct &= kk.check_is_allclose(\"out\", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)\n    is_correct &= kk.check_is_allclose(\"lse\", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)\n    assert is_correct\n\n    if t.test_performance:\n        time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time(\"flash_fwd_splitkv_mla_kernel\")\n\n        mean_attended_seqlens = cache_seqlens.float().mean().item()\n        compute_volume_flop = t.b * t.h_q * t.s_q * sum([\n            2 * t.d * mean_attended_seqlens,   # Q * K^T\n            2 * mean_attended_seqlens * t.dv,  # attention * V\n        ])\n        q_elem_size = torch.bfloat16.itemsize\n        kv_token_size = t.d * torch.bfloat16.itemsize\n        memory_volume_B = t.b * sum([\n            t.s_q * t.h_q * (t.d * q_elem_size),    # Q\n            mean_attended_seqlens * t.h_kv * kv_token_size,    # K/V\n            t.s_q * t.h_q * (t.dv * q_elem_size),   # Output\n        ])\n        achieved_tflops = compute_volume_flop / time_usage / 1e12\n        achieved_gBps = memory_volume_B / time_usage / 1e9\n\n        print(f\"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s\")\n\n\ndef main(torch_dtype):\n    device = torch.device(\"cuda:0\")\n    torch.set_default_dtype(torch_dtype)\n    torch.set_default_device(device)\n    torch.cuda.set_device(device)\n\n    cc_major, cc_minor = torch.cuda.get_device_capability()\n    assert cc_major == 9, \"Dense MLA decoding is only supported on sm90 (Hopper) currently.\"\n\n    correctness_cases = [\n        TestParam(b, s_q, s_k, is_varlen, is_causal, test_performance=False, have_zero_seqlen_k=False, block_size=64, h_q=h_q, h_kv=h_kv)\n        for b in [1, 2, 6, 64]\n        for s_q in [1, 2, 4]\n        for s_k in [20, 140, 4096]\n        
for h_q in [1, 3, 9, 63, 64, 126, 128]\n        for h_kv in [1, 2, 3, 8]\n        for is_varlen in [False, True]\n        for is_causal in [False, True]\n        if h_q % h_kv == 0\n    ]\n\n    corner_cases = [\n        # Cases where some kv cache have zero length\n        TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, test_performance=False, have_zero_seqlen_k=True, h_q=h_q, h_kv=h_kv)\n        for h_q in [1, 3, 9, 63, 64, 126, 128]\n        for h_kv in [1, 2, 3, 8]\n        for is_causal in [False, True]\n        if h_q % h_kv == 0\n    ]\n\n    performance_cases = [\n        TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, test_performance=True)\n        for is_causal in [False, True]\n        for s_q in [1, 2]\n        for s_k in [4096, 8192, 16384, 32768]\n    ]\n\n    testcases = correctness_cases + corner_cases + performance_cases\n\n    for testcase in testcases:\n        test_flash_mla(testcase)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        choices=[\"bf16\", \"fp16\"],\n        default=\"bf16\",\n        help=\"Data type to use for testing (bf16 or fp16)\",\n    )\n\n    args = parser.parse_args()\n\n    torch_dtype = torch.bfloat16\n    if args.dtype == \"fp16\":\n        torch_dtype = torch.float16\n\n    main(torch_dtype)"
  },
  {
    "path": "tests/test_flash_mla_sparse_decoding.py",
    "content": "import time\nimport dataclasses\nfrom typing import Tuple, List, Dict, Optional\nimport copy\n\nimport rich.console\nimport rich.table\n\nimport torch\nimport kernelkit as kk\n\nimport flash_mla\n\nimport lib\nfrom lib import TestParam\nfrom lib import RawTestParamForDecode as RawTestParam\nimport ref\n\n\"\"\"\nGenerate testcase for unit test\n\"\"\"\n\ndef gen_testcase() -> List[RawTestParam]:\n    correctness_cases = []\n    corner_cases = []\n    for d_qk in [576, 512]:\n        for have_extra_k in ([False, True] if d_qk == 512 else [False]):\n            for have_extra_topk_len in ([False, True] if have_extra_k else [False]):\n                for have_topk_len in ([False, True] if d_qk == 512 else [False]):\n                    for h_q in [64, 128]:\n                        cur_correctness_cases = [\n                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,\n                                        have_topk_length=have_topk_len,\n                                        enable_attn_sink=True,\n                                        extra_s_k=extra_s_k,\n                                        extra_topk=extra_topk,\n                                        block_size=block_size,\n                                        extra_block_size=extra_block_size,\n                                        have_extra_topk_length=have_extra_topk_len,\n                                        d_qk=d_qk,\n                                        check_correctness=True,\n                                        num_runs=0)\n                            for (s_k, topk, block_size) in [\n                                (512, 64, 2),\n                                (512, 64, 64),\n                                (512, 64, 69),\n                                (1024, 576, 2),\n                                (1024, 576, 61),\n                                (2046, 2048, 2),\n                                (2046, 2048, 64),\n                        
        (2046, 2048, 576)\n                            ]\n                            for (extra_s_k, extra_topk, extra_block_size) in ([\n                                (512, 64, 2),\n                                (512, 64, 64),\n                                (512, 64, 69),\n                                (1024, 576, 2),\n                                (1024, 576, 61),\n                                (2046, 2048, 2),\n                                (2046, 2048, 64),\n                                (2046, 2048, 576)\n                            ] if have_extra_k else [(None, None, None)])\n                            for b in [4, 74, 321]\n                            for s_q in [1, 3]\n                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])\n                        ]\n                        correctness_cases.extend(cur_correctness_cases)\n\n                        cur_corner_cases = [\n                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,\n                                        is_all_indices_invalid=is_all_indices_invalid,\n                                        have_zero_seqlen_k=have_zero_seqlen_k,\n                                        have_topk_length=have_topk_len,\n                                        enable_attn_sink=enable_attn_sink,\n                                        extra_s_k=extra_s_k,\n                                        extra_topk=extra_topk,\n                                        block_size=block_size,\n                                        extra_block_size=extra_block_size,\n                                        have_extra_topk_length=have_extra_topk_len,\n                                        d_qk=d_qk,\n                                        check_correctness=True,\n                                        num_runs=0,\n                            )\n                            for (s_k, topk, block_size) in 
[\n                                (512, 64, 61),\n                                (650, 576, 53),\n                            ]\n                            for (extra_s_k, extra_topk, extra_block_size) in ([\n                                (512, 64, 61),\n                                (650, 576, 53),\n                            ] if have_extra_k else [(None, None, None)])\n                            for b in [4, 74, 321]\n                            for s_q in [3]\n                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])\n                            for is_all_indices_invalid in [True, False]\n                            for have_zero_seqlen_k in [True, False]\n                            for enable_attn_sink in [True, False]\n                            if (is_all_indices_invalid or have_zero_seqlen_k or enable_attn_sink)\n                        ]\n                        corner_cases.extend(cur_corner_cases)\n\n    base_and_bszs = [\n        # V3.2\n        (RawTestParam(0, 128, 2, 1, 32768, True, topk=2048, d_qk=576), [2, 64, 74, 128]),\n        # MODEL1 CONFIG1\n        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=512, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),\n        # MODEL1 CONFIG2\n        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),\n        # MODEL1 CONFIG3\n        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),\n        # MODEL1 CONFIG4\n        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),\n    ]\n   
 performance_cases = [\n        # Production cases\n        dataclasses.replace(base, b=b)\n        for base, bszs in base_and_bszs\n        for b in bszs\n    ] + [\n        # Peak perf cases\n        RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk)\n        for h_q in [64, 128]\n        for d_qk in [512, 576]\n    ]\n\n    return correctness_cases + corner_cases + performance_cases\n\n\n@dataclasses.dataclass\nclass Result:\n    is_correct: bool\n    compute_memory_ratio: float\n    time_usage_per_us: float\n    splitkv_time_usage_us: float\n    combine_time_usage_us: float\n    achieved_tflops: float\n    achieved_gBps: float\n\n_counter = kk.Counter()\n\n@torch.inference_mode()\ndef test_flash_mla(p: TestParam) -> Result:\n    if p.seed == -1:\n        global _counter\n        p.seed = _counter.next()\n    assert p.decode\n\n    print(\"================\")\n    print(f\"Running on {p}\")\n    torch.cuda.empty_cache()\n\n    t = lib.generate_testcase_for_decode(p)\n\n    tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()\n    def run_decode():\n        return lib.run_flash_mla_decode(p, t, tile_scheduler_metadata, None)\n    \n    # We first run the kernel once to generate output data for the correctness test\n    # We must do this first, otherwise when allocating tensors for storing answers,\n    # it may re-use memory that contains the correct answer, leading to false positives\n    if p.check_correctness:\n        torch.cuda.synchronize()\n        out_ans, lse_ans = run_decode()\n        torch.cuda.synchronize()\n        # torch.set_printoptions(profile='full')\n        # print(tile_scheduler_metadata.tile_scheduler_metadata[:, :7])\n    \n    # We run the performance test before generating the answer for the correctness test to avoid interference\n    performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    if p.num_runs == 0:\n        performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    else:\n        
result = kk.bench_kineto(run_decode, p.num_runs)\n\n        splitkv_kernel_name = \"flash_fwd_splitkv_mla_fp8_sparse_kernel\"\n        combine_kernel_name = \"flash_fwd_mla_combine_kernel\"\n        \n        # Get individual kernel time usages\n        kernel_time_usages_us: Dict[str, Optional[float]] = {}\n        def pick_kernel_time_usage(kernel_name: str):\n            t = [kernel_name in s for s in result.get_kernel_names()]\n            if any(t):\n                assert sum(t) == 1\n                kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6\n            else:\n                kernel_time_usages_us[kernel_name] = None\n        pick_kernel_time_usage(splitkv_kernel_name)\n        pick_kernel_time_usage(combine_kernel_name)\n\n        # Get E2E time usages\n        def have_kernel(name: str):\n            return kernel_time_usages_us[name] is not None\n        \n        if kk.is_using_profiling_tools():\n            e2e_time_usage_us = 1e6\n        else:\n            assert have_kernel(splitkv_kernel_name)\n            if have_kernel(combine_kernel_name):\n                e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6\n            else:\n                e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name]\n        assert e2e_time_usage_us is not None\n\n        flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t)\n\n        e2e_time_usage_s = e2e_time_usage_us / 1e6\n        theoritical_compute_memory_ratio = flops_and_mem_vol.flop / flops_and_mem_vol.mem_vol\n        achieved_tflops = flops_and_mem_vol.flop / e2e_time_usage_s / 1e12\n        achieved_gBps = flops_and_mem_vol.mem_vol / e2e_time_usage_s / 1e9\n        def print_kernel_time_usage(name: str, short_name: str):\n            if kernel_time_usages_us[name] is not None:\n                print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us')\n        print(f'Compute/Memory: 
{theoritical_compute_memory_ratio:.2f}')\n        print(f'Time (per): {e2e_time_usage_us:.1f} us')\n        print_kernel_time_usage(splitkv_kernel_name, \"Splitkv\")\n        print_kernel_time_usage(combine_kernel_name, \"Combine\")\n        print(f'TFlops: {achieved_tflops:.1f}')\n        print(f'GB/s: {achieved_gBps:.0f}')\n\n        performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps)\n    \n    is_correct = True\n    if p.check_correctness:\n        torch.cuda.synchronize()\n        with torch.profiler.record_function(\"reference_flash_mla\"):\n            out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)\n\n        is_out_correct = kk.check_is_allclose(\"out\", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)\n        is_lse_correct = kk.check_is_allclose(\"lse\", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)\n        is_correct &= is_out_correct and is_lse_correct\n\n    performance_result.is_correct = is_correct\n    return performance_result\n\n\ndef main():\n    dtype = torch.bfloat16\n    device = torch.device(\"cuda:0\")\n    torch.set_default_dtype(dtype)\n    torch.set_default_device(device)\n    torch.cuda.set_device(device)\n    torch.set_float32_matmul_precision('high')\n    torch.set_num_threads(32)\n\n    raw_testcases = gen_testcase()\n    testcases = [t.to_test_param() for t in raw_testcases]\n\n    print(f\"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}\")\n\n    is_no_cooldown = lib.is_no_cooldown()\n    num_testcases_len = len(str(len(testcases)))\n    failed_cases = []\n    results: List[Tuple[TestParam, Result]] = []\n    for testcase_idx, testcase in enumerate(testcases):\n        if testcase != testcases[0] and testcase.num_runs > 0 and not is_no_cooldown:\n            time.sleep(0.3) # Cooldown\n        
print(f\"[{testcase_idx+1:{num_testcases_len}d}/{len(testcases)}, {testcase_idx/len(testcases)*100:3.0f}%]  \", end='')\n        result = test_flash_mla(testcase)\n        results.append((testcase, result))\n        if not result.is_correct:\n            failed_cases.append(testcase)\n            import sys\n            sys.exit(1)\n\n    console = rich.console.Console(width=120)\n    table = rich.table.Table(show_header=True, header_style=\"bold cyan\")\n    table.add_column(\"topk\")\n    table.add_column(\"Bsz\")\n    table.add_column(\"h_q&k\")\n    table.add_column(\"sq\")\n    table.add_column(\"sk\")\n    table.add_column(\"d_qk\")\n    table.add_column(\"Feats\")\n    table.add_column(\"C/M\")\n    table.add_column(\"TFlops\")\n    table.add_column(\"GBps\")\n    table.add_column(\"us\")\n    table.add_column(\" \")\n\n    for testcase, result in results:\n        assert testcase.decode\n        topk_str = f\"{testcase.topk}\" if testcase.decode.extra_topk is None else f\"{testcase.topk}+{testcase.decode.extra_topk}\"\n        table.add_row(\n            topk_str,\n            str(testcase.decode.b),\n            f\"{testcase.h_q:3d} {testcase.h_kv}\",\n            str(testcase.s_q),\n            str(testcase.s_kv),\n            str(testcase.d_qk),\n            \" V\"[testcase.decode.is_varlen] + \" L\"[testcase.have_topk_length] + \" E\"[testcase.decode.have_extra_topk_length],\n            f\"{result.compute_memory_ratio:3.0f}\",\n            f\"{result.achieved_tflops:3.0f}\",\n            f\"{result.achieved_gBps:4.0f}\",\n            f\"{result.time_usage_per_us:4.1f}\",\n            \"\" if result.is_correct else \"X\"\n        )\n    console.print(table)\n\n    def geomean(l) -> float:\n        import numpy\n        return numpy.exp(numpy.mean(numpy.log(l)))\n    \n    num_correct_testcases = [result.is_correct for t, result in results if t.check_correctness].count(True)\n    num_correctness_cases = sum([1 for t in testcases if 
t.check_correctness])\n    if num_correct_testcases == num_correctness_cases:\n        print(f\"{kk.colors['GREEN_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}\")\n    else:\n        print(f\"{kk.colors['RED_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}\")\n        for t in failed_cases:\n            print(f\"\\t{t},\")\n\n    valid_achieved_tflops = [result.achieved_tflops for _, result in results if result.achieved_tflops > 0.1]\n    if len(valid_achieved_tflops) > 0:\n        achieved_tflops_geomean = geomean(valid_achieved_tflops)    # > 0.1 to prune out correctness cases\n        print(f\"TFlops     geomean: {achieved_tflops_geomean:.1f}\")\n    \n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/test_flash_mla_sparse_prefill.py",
    "content": "import time\nimport sys\n\nimport torch\nimport kernelkit as kk\n\nfrom lib import TestParam\nimport lib\nimport ref\n\n_counter = kk.Counter()\n\n@torch.inference_mode()\ndef run_test(p: TestParam) -> bool:\n    if p.seed == -1:\n        global _counter\n        p.seed = _counter.next()\n\n    print(\"================\")\n    print(f\"Running on {p}\")\n    torch.cuda.empty_cache()\n\n    t = lib.generate_testcase(p)\n    torch.cuda.synchronize()\n    \n    def run_prefill():\n        return lib.run_flash_mla_sparse_fwd(p, t, False)\n    \n    prefill_ans_out, prefill_ans_max_logits, prefill_ans_lse = run_prefill()\n    torch.cuda.synchronize()\n\n    if p.num_runs > 0:\n        flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t)\n        prefill_ans_time = kk.bench_kineto(run_prefill, num_tests=p.num_runs).get_kernel_time(\"sparse_attn_fwd\")\n        prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12\n        prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12\n        print(f\"Prefill:  {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps\")\n\n    if p.check_correctness:\n        torch.cuda.synchronize()\n        ref_out, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, t)\n        ref_lse[ref_lse == float(\"-inf\")] = float(\"+inf\")\n        torch.cuda.synchronize()\n\n        is_correct = True\n        is_correct &= kk.check_is_allclose(\"out\", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)\n        is_correct &= kk.check_is_allclose(\"max_logits\", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)\n        is_correct &= kk.check_is_allclose(\"lse\", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)\n\n        return is_correct\n    else:\n        return True\n\n\nif __name__ == '__main__':\n    device = torch.device(\"cuda:0\")\n    torch.set_default_dtype(torch.bfloat16)\n    
torch.set_default_device(device)\n    torch.cuda.set_device(device)\n    torch.set_float32_matmul_precision('high')\n\n    correctness_cases = [\n        # Regular shapes\n        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)\n        for d_qk in [512, 576]\n        for h_q in [\n            128, 64\n        ]\n        for s_kv, topk in [\n            # Regular shapes\n            (128, 128),\n            (256, 256),\n            (512, 512),\n\n            # Irregular shapes\n            (592, 128),\n            (1840, 256),\n            (1592, 384),\n            (1521, 512),\n\n            # Irregular shapes with OOB TopK\n            (95, 128),\n            (153, 256),\n            (114, 384),\n        ]\n        for s_q in [\n            1, 62, 213\n        ]\n    ]\n\n    correctness_cases_with_features = [\n        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk)\n        for d_qk in [512, 576]\n        for h_q in [\n            128, 64\n        ]\n        for s_kv, topk in [\n            (592, 128),\n            (1840, 256),\n            (1592, 384),\n            (1521, 512),\n\n            (95, 128),\n            (153, 256),\n            (114, 384),\n        ]\n        for s_q in [62, 213]\n        for have_sink_lse in [False, True]\n        for have_attn_sink in [False, True]\n        for have_topk_length in [False, True]\n    ]\n\n    corner_cases = [\n        TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)\n        for d_qk in [512, 576]\n        for h_q in [\n            128, 64\n        ]\n        for s_q, s_kv, topk in [\n            (1, 128, 128),\n            (1, 256, 256),\n            (1234, 4321, 4096),\n            (4096, 2048, 2048)\n        ]\n    ] + [\n        # In these cases, some blocks may not have any valid topk indices\n        TestParam(s_q, s_kv, topk, h_q=h_q, 
is_all_indices_invalid=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)\n        for d_qk in [512, 576]\n        for h_q in [\n            128, 64\n        ]\n        for s_kv, topk in [\n            (32, 2048),\n            (64, 8192)\n        ]\n        for s_q in [1, 1024]\n    ] + [\n        # In this testcase, s_q is really large, so we cannot put it on the second dimension of grid shape\n        TestParam(70000, 256, 256, h_q=h_q, check_correctness=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)\n        for d_qk in [512, 576]\n        for h_q in [\n            128, 64\n        ]\n    ]\n\n    performance_case_templates = [\n        # V3.2\n        (576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),\n        # MODEL1 CONFIG1\n        (512, 64, 512, [8192, 32768, 49152, 65536]),\n        # MODEL1 CONFIG2\n        (512, 128, 1024, [8192, 32768, 49152, 65536]),\n    ]\n\n    performance_cases = [\n        TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True)\n        for (d_qk, h_q, topk, s_kv_list) in performance_case_templates\n        for s_q in [4096]\n        for s_kv in s_kv_list\n    ]\n\n    testcases = correctness_cases + correctness_cases_with_features + corner_cases + performance_cases\n\n    is_no_cooldown = lib.is_no_cooldown()\n    failed_cases = []\n    for test in testcases:\n        if test != testcases[0] and test.num_runs > 0 and not is_no_cooldown:\n            time.sleep(0.3)\n        is_correct = run_test(test)\n        if not is_correct:\n            failed_cases.append(test)\n    \n    if len(failed_cases) > 0:\n        print(f\"\\033[31m\\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\\033[0m\")\n        for case in failed_cases:\n            print(f\"    {case}\")\n        sys.exit(1)\n    else:\n        print(f\"\\033[32m\\033[1mAll {len(testcases)} cases passed!\\033[0m\")\n\n"
  },
  {
    "path": "tests/test_fmha_sm100.py",
    "content": "import random\n\nimport torch\nfrom torch.utils.checkpoint import checkpoint\nimport triton\n\nfrom flash_mla import flash_attn_varlen_func\nfrom kernelkit import check_is_allclose\n\ndef get_window_size(causal, window):\n    if window > 0:\n        window_size = (window - 1, 0) if causal else (window - 1, window - 1)\n    else:\n        window_size = (-1, -1)\n    return window_size\n\n\ndef get_attn_bias(s_q, s_k, causal, window):\n    attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32)\n    if causal:\n        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)\n        attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n    if window > 0:\n        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q - window)\n        attn_bias.masked_fill_(temp_mask, float(\"-inf\"))\n        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q + window - 1)\n        attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n    return attn_bias\n\n\ndef sdpa(query, key, value, attn_bias, softmax_scale=None):\n    query = query.float().transpose(-3, -2)\n    key = key.float().transpose(-3, -2)\n    value = value.float().transpose(-3, -2)\n    key = key.repeat_interleave(h // h_k, dim=-3)\n    value = value.repeat_interleave(h // h_k, dim=-3)\n    if softmax_scale is None:\n        softmax_scale = query.shape[-1] ** (-0.5)\n    attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale\n    attn_weight += attn_bias\n    lse = attn_weight.logsumexp(dim=-1)\n    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)\n    return attn_weight.to(query.dtype) @ value, lse\n\n\ndef sdpa_checkpoint(*args, **kwargs):\n    return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)\n\n\ndef test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, check_correctness: bool = True):\n    print(f\"{b=}, {mean_sq=}, {mean_sk=}, 
{varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}, {window=}, {has_bwd=}, {check_correctness=}\")\n    torch.manual_seed(0)\n    random.seed(0)\n\n    seqlens_q = torch.full((b,), mean_sq, dtype=torch.int32)\n    seqlens_k = torch.full((b,), mean_sk, dtype=torch.int32)\n\n    if varlen:\n        for i in range(b):\n            seqlens_q[i] = max(random.normalvariate(mean_sq, mean_sq / 2), 1)\n        for i in range(b):\n            seqlens_k[i] = max(random.normalvariate(mean_sk, mean_sk / 2), seqlens_q[i].item())\n    cu_seqlens_q = torch.cumsum(torch.nn.functional.pad(seqlens_q, (1, 0)), 0, dtype=torch.int32)\n    cu_seqlens_k = torch.cumsum(torch.nn.functional.pad(seqlens_k, (1, 0)), 0, dtype=torch.int32)\n    total_q = seqlens_q.sum().item()\n    total_k = seqlens_k.sum().item()\n    max_seqlen_q = seqlens_q.max().item()\n    max_seqlen_k = seqlens_k.max().item()\n    total_attn_compute = sum([(get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(),\n                             causal, window) == 0).sum().item() for i in range(b)])\n    # print(f\"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}\")\n\n    q = torch.randn(total_q, h, d) / 10\n    k = torch.randn(total_k, h_k, d) / 10\n    v = torch.randn(total_k, h_k, dv) / 10\n    grad_out = torch.randn(total_q, h, dv) / 10\n    softmax_scale = (d + 100) ** (-0.5)\n\n    q1 = q.clone().requires_grad_()\n    k1 = k.clone().requires_grad_()\n    v1 = v.clone().requires_grad_()\n\n    if check_correctness:\n        q2 = q.clone().requires_grad_()\n        k2 = k.clone().requires_grad_()\n        v2 = v.clone().requires_grad_()\n\n    def flash_attn():\n        q1.grad = k1.grad = v1.grad = None\n        kwargs = {}\n        if causal:\n            kwargs[\"causal\"] = causal\n        if window != 0:\n            kwargs[\"window_size\"] = get_window_size(causal, window)\n        return flash_attn_varlen_func(q1, k1, v1, cu_seqlens_q, cu_seqlens_k, 
max_seqlen_q,\n                                      max_seqlen_k, softmax_scale=softmax_scale, is_varlen=varlen, **kwargs)\n\n    def torch_attn():\n        q2.grad = k2.grad = v2.grad = None\n        out = []\n        lse = []\n        for i in range(b):\n            OUT, LSE = sdpa_checkpoint(\n                q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()],\n                k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()],\n                v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()],\n                attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window),\n                softmax_scale=softmax_scale,\n            )\n            out.append(OUT.transpose(-3, -2))\n            lse.append(LSE.transpose(-2, -1))\n        out = torch.cat(out)\n        lse = torch.cat(lse)\n        return out, lse\n\n    out_flash, lse_flash = flash_attn()\n    if has_bwd:\n        out_flash.backward(grad_out, retain_graph=True)\n        _dq1 = q1.grad.clone()\n        dk1 = k1.grad.clone()\n        dv1 = v1.grad.clone()\n\n    if check_correctness:\n        out_torch, lse_torch = torch_attn()\n        assert check_is_allclose(\"out\", out_flash.float(), out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)\n        assert check_is_allclose(\"lse\", lse_flash.float(), lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536)\n\n        if has_bwd:\n            out_torch.backward(grad_out, retain_graph=True)\n            assert check_is_allclose(\"dq\", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)\n            assert check_is_allclose(\"dk\", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)\n            assert check_is_allclose(\"dv\", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6)\n\n    def forward():\n        return flash_attn()\n\n    def backward():\n        q1.grad = k1.grad = v1.grad = None\n        out_flash.backward(grad_out, retain_graph=True)\n\n    for _ 
in range(5):\n        out, lse = forward()\n        assert torch.equal(out, out_flash), \"out deterministic check failed!\"\n        assert torch.equal(lse, lse_flash), \"lse deterministic check failed!\"\n        if has_bwd:\n            backward()\n            # assert torch.equal(q1.grad, dq1), \"dq deterministic check failed!\"\n            assert torch.equal(k1.grad, dk1), \"dk deterministic check failed!\"\n            assert torch.equal(v1.grad, dv1), \"dv deterministic check failed!\"\n\n    def timer(func, name):\n        t = triton.testing.do_bench(func, warmup=2, rep=3)\n        FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == \"fwd\" else ((d * 3 + dv * 2)))\n        print(f\"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOP/s, name: {name}\")\n        return t\n\n    timer(forward, \"fwd\")\n    if has_bwd:\n        timer(backward, \"bwd\")\n\n\nif __name__ == \"__main__\":\n    dtype = torch.bfloat16\n    torch.set_default_dtype(dtype)\n    device = torch.device(\"cuda:0\")\n    torch.set_default_device(device)\n    torch.cuda.set_device(device)\n    torch.set_float32_matmul_precision(\"high\")\n\n    b = 2\n    window = 0\n    has_bwd = False\n\n    for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]:\n        for varlen in [False, True]:\n            for (h, h_k) in [(128, 128), (32, 4)]:\n                if h != h_k:\n                    has_bwd = False\n                else:\n                    has_bwd = True\n                for (d, dv) in [(128, 128), (192, 128)]:\n                    for causal in [False, True]:\n                        skip_correctness_check = mean_sq == 8192 and mean_sk == 8192 and h == 128\n                        test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, not skip_correctness_check)\n"
  }
]