Showing preview only (1,140K chars total). Download the full file or copy to clipboard to get everything.
Repository: deepseek-ai/FlashMLA
Branch: main
Commit: 47c35a712362
Files: 129
Total size: 1.1 MB
Directory structure:
gitextract_clsc5nbn/
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── benchmark/
│ ├── bench_flash_mla.py
│ └── visualize.py
├── csrc/
│ ├── api/
│ │ ├── api.cpp
│ │ ├── common.h
│ │ ├── dense_decode.h
│ │ ├── dense_fwd.h
│ │ ├── sparse_decode.h
│ │ └── sparse_fwd.h
│ ├── defines.h
│ ├── kerutils/
│ │ └── include/
│ │ └── kerutils/
│ │ ├── common/
│ │ │ └── common.h
│ │ ├── device/
│ │ │ ├── common.h
│ │ │ ├── device.cuh
│ │ │ ├── sm100/
│ │ │ │ ├── gemm.cuh
│ │ │ │ ├── helpers.cuh
│ │ │ │ ├── intrinsics.cuh
│ │ │ │ └── tma_cta_group2_nosplit.cuh
│ │ │ ├── sm80/
│ │ │ │ ├── helpers.cuh
│ │ │ │ └── intrinsics.cuh
│ │ │ └── sm90/
│ │ │ ├── helpers.cuh
│ │ │ └── intrinsics.cuh
│ │ ├── host/
│ │ │ └── host.h
│ │ ├── kerutils.cuh
│ │ └── supplemental/
│ │ └── torch_tensors.h
│ ├── params.h
│ ├── sm100/
│ │ ├── decode/
│ │ │ ├── head128/
│ │ │ │ └── README.md
│ │ │ └── head64/
│ │ │ ├── config.h
│ │ │ ├── instantiations/
│ │ │ │ ├── model1.cu
│ │ │ │ └── v32.cu
│ │ │ ├── kernel.cuh
│ │ │ └── kernel.h
│ │ ├── helpers.h
│ │ └── prefill/
│ │ ├── dense/
│ │ │ ├── collective/
│ │ │ │ ├── fmha_common.hpp
│ │ │ │ ├── fmha_fusion.hpp
│ │ │ │ ├── sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
│ │ │ │ ├── sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
│ │ │ │ ├── sm100_fmha_load_tma_warpspecialized.hpp
│ │ │ │ ├── sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
│ │ │ │ └── sm100_fmha_mla_load_tma_warpspecialized.hpp
│ │ │ ├── common/
│ │ │ │ ├── gather_tensor.hpp
│ │ │ │ ├── helper.h
│ │ │ │ ├── mask.cuh
│ │ │ │ ├── pipeline_mla.hpp
│ │ │ │ ├── pow_2.hpp
│ │ │ │ └── utils.hpp
│ │ │ ├── device/
│ │ │ │ ├── fmha.hpp
│ │ │ │ └── fmha_device_bwd.hpp
│ │ │ ├── fmha_cutlass_bwd_sm100.cu
│ │ │ ├── fmha_cutlass_bwd_sm100.cuh
│ │ │ ├── fmha_cutlass_fwd_sm100.cu
│ │ │ ├── fmha_cutlass_fwd_sm100.cuh
│ │ │ ├── interface.h
│ │ │ └── kernel/
│ │ │ ├── fmha_causal_tile_scheduler.hpp
│ │ │ ├── fmha_kernel_bwd_convert.hpp
│ │ │ ├── fmha_kernel_bwd_sum_OdO.hpp
│ │ │ ├── fmha_options.hpp
│ │ │ ├── fmha_tile_scheduler.hpp
│ │ │ ├── sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
│ │ │ ├── sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp
│ │ │ └── sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
│ │ └── sparse/
│ │ ├── common_subroutine.h
│ │ ├── fwd/
│ │ │ ├── head128/
│ │ │ │ ├── config.h
│ │ │ │ ├── instantiations/
│ │ │ │ │ ├── phase1_k512.cu
│ │ │ │ │ └── phase1_k576.cu
│ │ │ │ ├── phase1.cuh
│ │ │ │ └── phase1.h
│ │ │ └── head64/
│ │ │ ├── config.h
│ │ │ ├── instantiations/
│ │ │ │ ├── phase1_k512.cu
│ │ │ │ └── phase1_k576.cu
│ │ │ ├── phase1.cuh
│ │ │ └── phase1.h
│ │ └── fwd_for_small_topk/
│ │ └── head128/
│ │ ├── config.h
│ │ ├── instantiations/
│ │ │ ├── phase1_decode_k512.cu
│ │ │ └── phase1_prefill_k512.cu
│ │ ├── phase1.cuh
│ │ └── phase1.h
│ ├── sm90/
│ │ ├── decode/
│ │ │ ├── dense/
│ │ │ │ ├── config.h
│ │ │ │ ├── instantiations/
│ │ │ │ │ ├── bf16.cu
│ │ │ │ │ └── fp16.cu
│ │ │ │ ├── splitkv_mla.cuh
│ │ │ │ ├── splitkv_mla.h
│ │ │ │ └── traits.h
│ │ │ └── sparse_fp8/
│ │ │ ├── components/
│ │ │ │ ├── config.h
│ │ │ │ ├── dequant.h
│ │ │ │ └── helpers.h
│ │ │ ├── config.h
│ │ │ ├── instantiations/
│ │ │ │ ├── model1_persistent_h128.cu
│ │ │ │ ├── model1_persistent_h64.cu
│ │ │ │ ├── v32_persistent_h128.cu
│ │ │ │ └── v32_persistent_h64.cu
│ │ │ ├── splitkv_mla.cuh
│ │ │ └── splitkv_mla.h
│ │ ├── helpers.h
│ │ └── prefill/
│ │ └── sparse/
│ │ ├── config.h
│ │ ├── fwd.cu
│ │ ├── fwd.h
│ │ ├── instantiations/
│ │ │ ├── phase1_k512.cu
│ │ │ ├── phase1_k512_topklen.cu
│ │ │ ├── phase1_k576.cu
│ │ │ └── phase1_k576_topklen.cu
│ │ ├── phase1.cuh
│ │ └── phase1.h
│ ├── smxx/
│ │ └── decode/
│ │ ├── combine/
│ │ │ ├── combine.cu
│ │ │ └── combine.h
│ │ └── get_decoding_sched_meta/
│ │ ├── get_decoding_sched_meta.cu
│ │ └── get_decoding_sched_meta.h
│ └── utils.h
├── docs/
│ ├── 20250422-new-kernel-deep-dive.md
│ └── 20250929-hopper-fp8-sparse-deep-dive.md
├── flash_mla/
│ ├── __init__.py
│ └── flash_mla_interface.py
├── setup.py
└── tests/
├── kernelkit/
│ ├── .gitignore
│ ├── __init__.py
│ ├── bench.py
│ ├── compare.py
│ ├── generate.py
│ ├── precision.py
│ └── utils.py
├── lib.py
├── quant.py
├── ref.py
├── test_flash_mla_dense_decoding.py
├── test_flash_mla_sparse_decoding.py
├── test_flash_mla_sparse_prefill.py
└── test_fmha_sm100.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
build
*.so
*.egg-info/
__pycache__/
dist/
*perf.csv
*.png
/.vscode
compile_commands.json
.cache
/dev
/.clangd
================================================
FILE: .gitmodules
================================================
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# FlashMLA
## Introduction
FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations:
**Sparse Attention Kernels**
*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).*
- Token-level sparse attention for the prefill stage
- Token-level sparse attention for the decoding stage, with FP8 KV cache
**Dense Attention Kernels**
- Dense attention for the prefill stage
- Dense attention for the decoding stage
## News
- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. We also release a deep-dive blog for our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md).
- **2025.08.01 Kernels for MHA on SM100**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on SM100!
- **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md).
- **2025.04.22 Performance Update**: We're excited to announce the new release of FlashMLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀
## Performance
#### Test & benchmark MLA decoding (Sparse & Dense):
```bash
python tests/test_flash_mla_dense_decoding.py
python tests/test_flash_mla_sparse_decoding.py
```
The dense MLA decoding kernel achieves up to 3000 GB/s in the memory-bound configuration and 660 TFLOPS in the compute-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in the compute-bound configuration on H800 SXM5 with CUDA 12.8, and up to 350 TFLOPS on B200 (where it is not yet fully optimized).
#### Test & benchmark MHA prefill (Dense):
```bash
python tests/test_fmha_sm100.py
```
It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation on B200, as reported by NVIDIA.
#### Test & benchmark MLA prefill (Sparse):
```bash
python tests/test_flash_mla_sparse_prefill.py
```
It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8, and up to 1450 TFlops on B200 with CUDA 12.9.
## Requirements
- SM90 / SM100 (See the support matrix below)
- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels)
- PyTorch 2.0 and above
Support matrix:
| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: |
| Dense Decoding | SM90 | MQA | BF16 |
| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] |
| Dense Prefill | SM100 | MHA | |
| Sparse Prefill | SM90 & SM100 | MQA | |
[1]: For more details on using FP8 KV cache, see documents below.
[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).
## Installation
```bash
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```
## Usage
### MLA Decoding
To use the MLA decoding kernels, call get_mla_metadata once before the decoding loop to get the tile scheduler metadata. Then, call flash_mla_with_kvcache in each decoding step. For example:
```python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens,
s_q * h_q // h_kv,
h_kv,
h_q,
is_fp8,
topk,
)
for i in range(num_layers):
...
o_i, lse_i = flash_mla_with_kvcache(
q_i, kvcache_i, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits,
is_causal, is_fp8_kvcache, indices,
)
...
```
Where
- `s_q` is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1.
- `h_kv` is the number of key-value heads.
- `h_q` is the number of query heads.
**FP8 KV Cache:**
If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.
In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy.
See `tests/quant.py` for quantization and dequantization details.
**Sparse Attention (`indices` tensor):**
The `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens.
- **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`.
- **Format:** `indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th token for the j-th query sequence in the i-th batch. Since the index of the page block has already been encoded into `indices_in_kvcache`, the kernel does not require the `block_table` parameter.
- **Invalid entries:** Set invalid indices to `-1`.
**Return Values:**
The kernel returns `(out, lse)`, where:
- `out` is the attention result.
- `lse` is the log-sum-exp value of the attention scores for each query head.
See `tests/test_flash_mla_dense_decoding.py` and `tests/test_flash_mla_sparse_decoding.py` for complete examples.
### Sparse MLA Prefill
For the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters:
- `q`: Query tensor of shape `[s_q, h_q, d_qk]`
- `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]`
- `indices`: Indices tensor of shape `[s_q, h_kv, topk]`
- `sm_scale`: A scalar value
**Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the `indices` parameter to simulate batch processing.
**Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`.
**Return Values and Equivalent PyTorch Code:**
The kernel returns `(out, max_logits, lse)`. This is equivalent to the following PyTorch operations:
```python
Q: [s_q, h_q, d_qk], bfloat16
kv: [s_kv, h_kv, d_qk], bfloat16
indices: [s_q, h_kv, topk], int32
kv = kv.squeeze(1) # [s_kv, d_qk], h_kv must be 1
indices = indices.squeeze(1) # [s_q, topk]
focused_kv = kv[indices] # For the i-th sequence (s_q), the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk].
P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e) # [s_q, h_q, topk]
max_logits = P.max(dim=-1) # [s_q, h_q]
lse = log2sumexp2(P, dim=-1, base=2) # [s_q, h_q],"log2sumexp2" means that the exponentiation and logarithm are base-2
S = exp2(P - lse) # [s_q, h_q, topk]
out = S @ focused_kv # [s_q, h_q, d_qk]
return (out, max_logits, lse)
```
See `tests/test_flash_mla_sparse_prefill.py` for a complete example.
### Dense MHA Prefill
This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using:
- `flash_attn_varlen_func`
- `flash_attn_varlen_qkvpacked_func`
- `flash_attn_varlen_kvpacked_func`
The usage is similar to the `flash_attn` package. See `tests/test_fmha_sm100.py` for a complete example.
## Acknowledgement
FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.
## Community Support
### MetaX
For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com).
The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA)
### Moore Threads
For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/).
The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA).
### Hygon DCU
For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/).
The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention).
### Intellifusion
For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com).
The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py).
### Iluvatar Corex
For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com).
The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla)
### AMD Instinct
For AMD Instinct GPUs, visit the official website: [AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html).
The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py)
## Citation
```bibtex
@misc{flashmla2025,
title={FlashMLA: Efficient Multi-head Latent Attention Kernels},
author={Jiashi Li, Shengyu Liu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}
```
================================================
FILE: benchmark/bench_flash_mla.py
================================================
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import argparse
import math
import random
import flashinfer
import torch
import triton
import triton.language as tl
# pip install flashinfer-python
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    """Reference GQA/MQA attention computed in float32.

    The leading dimension of every input is the head dimension; ``key`` and
    ``value`` are repeated ``h_q // h_kv`` times along it so that each query
    head gets its own copy of the shared KV head.

    Args:
        query: tensor laid out as (h_q, s_q, d) (as built by run_torch_mla).
        key: tensor laid out as (h_kv, s_k, d).
        value: tensor laid out as (h_kv, s_k, dv).
        h_q: number of query heads.
        h_kv: number of key/value heads.
        is_causal: if True, apply a bottom-right-aligned causal mask: query row
            ``i`` may attend to keys ``0 .. i + (s_k - s_q)``.

    Returns:
        ``(out, lse)``: ``out = softmax(Q K^T / sqrt(d)) @ V`` with shape
        (h_q, s_q, dv), and ``lse`` the natural-log log-sum-exp of the scaled,
        masked scores with shape (h_q, s_q). Both are float32.
    """
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        # Keep-mask aligned to the bottom-right corner so the last query row sees
        # every key. Build it on the scores' device instead of relying on the
        # global default device.
        keep = torch.ones(s_q, s_k, dtype=torch.bool, device=attn_weight.device).tril(diagonal=s_k - s_q)
        attn_weight = attn_weight.masked_fill(keep.logical_not(), float("-inf"))
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    """Reference MLA decode in plain PyTorch, timed with ``triton.testing.do_bench``.

    Side effect: in batch row ``i``, every slot of ``blocked_k`` at or past
    ``cache_seqlens[i]`` is overwritten in place with NaN (presumably so any read
    of padding shows up as NaN in a runner's output).

    ``block_table``, ``block_size`` and ``dtype`` are not used here; they keep the
    signature identical to the other runners in FUNC_TABLE.

    Returns:
        ``(out, lse, t)``: ``out`` is float32 of shape (b, s_q, h_q, dv), ``lse``
        is float32 of shape (b, h_q, s_q), and ``t`` is the do_bench timing.
    """
    per_row = blocked_k.view(b, max_seqlen_pad, h_kv, d)
    for row in range(b):
        per_row[row, cache_seqlens[row].item():] = float("nan")
    v_cache = blocked_k[..., :dv]

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for row in range(b):
            # Assumes row `row` owns tokens [lo, hi) of the flattened cache, which
            # holds for the arange block_table built in compare_ab / compare_a.
            lo = row * max_seqlen_pad
            hi = lo + cache_seqlens[row]
            keys = blocked_k.view(-1, h_kv, d)[lo:hi].transpose(0, 1)
            vals = v_cache.view(-1, h_kv, dv)[lo:hi].transpose(0, 1)
            o_row, lse_row = scaled_dot_product_attention(
                q[row].transpose(0, 1), keys, vals, h_q, h_kv, is_causal=causal,
            )
            out[row] = o_row.transpose(0, 1)
            lse[row] = lse_row
        return out, lse

    out_torch, lse_torch = ref_mla()
    return out_torch, lse_torch, triton.testing.do_bench(ref_mla)
@torch.inference_mode()
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    """Run FlashMLA's paged-KV decode (``flash_mla_with_kvcache``) and time it.

    Side effect: padding slots of ``blocked_k`` (at or past ``cache_seqlens[i]``
    in row ``i``) are overwritten in place with NaN.

    Returns:
        ``(out, lse, t)`` as returned by ``flash_mla_with_kvcache`` plus the
        ``triton.testing.do_bench`` timing.
    """
    per_row = blocked_k.view(b, max_seqlen_pad, h_kv, d)
    for row in range(b):
        per_row[row, cache_seqlens[row].item():] = float("nan")

    # Scheduler metadata is computed once, outside the timed closure.
    q_tokens_per_kv_head = s_q * h_q // h_kv
    sched_meta, splits = get_mla_metadata(cache_seqlens, q_tokens_per_kv_head, h_kv)

    def flash_mla():
        return flash_mla_with_kvcache(
            q, blocked_k, block_table, cache_seqlens, dv,
            sched_meta, splits, causal=causal,
        )

    out_flash, lse_flash = flash_mla()
    return out_flash, lse_flash, triton.testing.do_bench(flash_mla)
@torch.inference_mode()
def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    """Run FlashInfer's ``BatchMLAPagedAttentionWrapper`` (fa3 backend) and time it.

    ``q`` and ``blocked_k`` are split at channel ``dv`` into a no-RoPE part
    (first ``dv`` channels) and a RoPE part (remaining ``d - dv`` channels).

    Side effect: padding slots of ``blocked_k`` (at or past ``cache_seqlens[i]``
    in row ``i``) are overwritten in place with NaN.

    Returns:
        ``(out, lse, t)``: ``out`` reshaped to (b, s_q, h_q, dv), ``lse`` reshaped
        to (b, h_q, s_q) like the other runners, and the do_bench timing.
    """
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    assert d > dv, "mla with rope dim should be larger than no rope dim"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    # CSR page layout: pages of batch row i are kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
    # kv_indptr must have exactly b + 1 entries (the original code appended b - 1
    # extra entries in a second, redundant loop).
    kv_indptr = [0]
    kv_indices = []
    for i in range(b):
        seq_len = cache_seqlens[i].item()
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_table[i, :num_blocks].tolist())
        kv_indptr.append(kv_indptr[-1] + num_blocks)
    q_indptr = torch.arange(0, b + 1).int() * s_q
    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)

    mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        torch.empty(128 * 1024 * 1024, dtype=torch.int8),
        backend="fa3",
    )
    mla_wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        cache_seqlens,
        h_q,
        dv,
        d - dv,
        block_size,
        causal,
        1 / math.sqrt(d),
        q.dtype,
        blocked_k.dtype,
    )

    def flash_infer():
        output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True)
        # lse is assumed to be laid out per query token, (b * s_q, h_q); reorder to
        # (b, h_q, s_q). For s_q == 1 this equals the previous lse.view(b, h_q, 1).
        return output.view(b, -1, h_q, dv), lse.view(b, s_q, h_q).transpose(1, 2)

    out_flash, lse_flash = flash_infer()
    t = triton.testing.do_bench(flash_infer)
    return out_flash, lse_flash, t
# Stage 1 of the Triton split-KV MLA decode reference.
#
# Launched by _mla_attn on the grid (cdiv(num_heads, BLOCK_H), batch, NUM_KV_SPLITS):
# program (head block, batch row, split) handles BLOCK_H query heads of one batch
# row over one contiguous chunk of that row's KV sequence.
#
# For its chunk it runs an online softmax over QK^T (no-RoPE part plus RoPE part),
# using the latent cache chunk both as K (no-RoPE part) and as V. It writes:
#   O[b, h, split, 0:HEAD_DIM_CKV] = normalized partial output  (acc / e_sum)
#   O[b, h, split, HEAD_DIM_CKV]   = natural-log LSE of the chunk (e_max + log(e_sum))
# The per-split results are merged by _mla_softmax_reducev_kernel.
#
# NOTE(review): there is no head masking, so num_heads appears to be assumed to be a
# multiple of BLOCK_H; no causal mask is applied either (presumably fine for one
# query token per row). The last dimension of each tensor is indexed with an
# implicit stride of 1.
@triton.jit
def _mla_attn_kernel(
    Q_nope,
    Q_pe,
    Kv_c_cache,
    K_pe_cache,
    Req_to_tokens,
    B_seq_len,
    O,
    sm_scale,
    stride_q_nope_bs,
    stride_q_nope_h,
    stride_q_pe_bs,
    stride_q_pe_h,
    stride_kv_c_bs,
    stride_k_pe_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HEAD_DIM_CKV: tl.constexpr,
    HEAD_DIM_KPE: tl.constexpr,
):
    cur_batch = tl.program_id(1)
    cur_head_id = tl.program_id(0)
    split_kv_id = tl.program_id(2)
    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
    # The BLOCK_H query heads handled by this program.
    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
    # Load the query tile: no-RoPE part (BLOCK_H, HEAD_DIM_CKV) and RoPE part (BLOCK_H, HEAD_DIM_KPE).
    offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
    q_nope = tl.load(Q_nope + offs_q_nope)
    offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
    offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :]
    q_pe = tl.load(Q_pe + offs_q_pe)
    # Online-softmax state: running max, running sum of exp, unnormalized output.
    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32)
    # This program's KV range [split_kv_start, split_kv_end); it may be empty for
    # trailing splits of short sequences (the reduce kernel skips those).
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
    for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        # Translate logical token positions to physical cache rows via the page table.
        kv_page_number = tl.load(
            Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE,
            mask=offs_n < split_kv_end,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        # K (no-RoPE part) is loaded transposed: (HEAD_DIM_CKV, BLOCK_N).
        offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None]
        k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0)
        qk = tl.dot(q_nope, k_c.to(q_nope.dtype))
        # Add the RoPE-part contribution, K_pe also loaded transposed: (HEAD_DIM_KPE, BLOCK_N).
        offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None]
        k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0)
        qk += tl.dot(q_pe, k_pe.to(q_pe.dtype))
        qk *= sm_scale
        # Tokens past the end of the split must not contribute to the softmax.
        qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf"))
        # The same latent tile serves as V: (BLOCK_N, HEAD_DIM_CKV).
        v_c = tl.trans(k_c)
        # Online-softmax update: rescale previous state to the new running max.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc += tl.dot(p.to(v_c.dtype), v_c)
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
    # Store the normalized partial output for this split...
    offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
    tl.store(O + offs_o, acc / e_sum[:, None])
    # ...and its log-sum-exp in the extra trailing slot at index HEAD_DIM_CKV.
    offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
    tl.store(O + offs_o_1, e_max + tl.log(e_sum))
def _mla_attn(
    q_nope,
    q_pe,
    kv_c_cache,
    k_pe_cache,
    attn_logits,
    req_to_tokens,
    b_seq_len,
    num_kv_splits,
    sm_scale,
    page_size,
):
    """Launch stage 1 (``_mla_attn_kernel``) of the Triton split-KV MLA decode.

    Uses 16 heads per program and 64 KV tokens per inner iteration, on the grid
    (cdiv(num_heads, 16), batch, num_kv_splits). Results (per-split normalized
    partial output plus trailing LSE slot) are written into ``attn_logits``.
    Head dimensions are taken from the last dim of ``q_nope`` and ``q_pe``.
    """
    n_batch = q_nope.shape[0]
    n_heads = q_nope.shape[1]
    block_h = 16
    block_n = 64
    launch_grid = (triton.cdiv(n_heads, block_h), n_batch, num_kv_splits)
    # Stride arguments, in the exact order the kernel declares them.
    strides = (
        q_nope.stride(0), q_nope.stride(1),
        q_pe.stride(0), q_pe.stride(1),
        kv_c_cache.stride(-2),
        k_pe_cache.stride(-2),
        req_to_tokens.stride(0),
        attn_logits.stride(0), attn_logits.stride(1), attn_logits.stride(2),
    )
    _mla_attn_kernel[launch_grid](
        q_nope, q_pe, kv_c_cache, k_pe_cache,
        req_to_tokens, b_seq_len, attn_logits, sm_scale,
        *strides,
        BLOCK_H=block_h,
        BLOCK_N=block_n,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        HEAD_DIM_CKV=q_nope.shape[-1],
        HEAD_DIM_KPE=q_pe.shape[-1],
    )
# Stage 2 of the Triton split-KV MLA decode reference: merge per-split partials.
#
# Launched by _mla_softmax_reducev on the grid (batch, num_heads): one program per
# (batch row, head). For each split it reads the normalized partial output
# Logits[b, h, split, 0:HEAD_DIM_CKV] and its natural-log LSE at index HEAD_DIM_CKV
# (as written by _mla_attn_kernel), and combines them with LSE-based weights using a
# running max for numerical stability. Splits that are empty under the same
# splitting formula as stage 1 are skipped, since stage 1 stored NaN / -inf there.
# The last dimension of Logits and O is indexed with an implicit stride of 1.
@triton.jit
def _mla_softmax_reducev_kernel(
    Logits,
    B_seq_len,
    O,
    stride_l_b,
    stride_l_h,
    stride_l_s,
    stride_o_b,
    stride_o_h,
    NUM_KV_SPLITS: tl.constexpr,
    HEAD_DIM_CKV: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
    # Running sum of weights, running max LSE, and weighted output accumulator.
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32)
    # Offsets of the partial output vector and of its trailing LSE slot (split 0).
    offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv
    offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV
    for split_kv_id in range(0, NUM_KV_SPLITS):
        # Recompute stage 1's split boundaries to detect empty splits.
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
        if split_kv_end > split_kv_start:
            logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s)
            logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s)
            # Rescale what has been accumulated so far to the new max LSE, then add
            # this split weighted by exp(lse_split - max).
            n_e_max = tl.maximum(logits_1, e_max)
            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(logits_1 - n_e_max)
            acc += exp_logic * logits
            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max
    tl.store(
        O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
        acc / e_sum,
    )
def _mla_softmax_reducev(
    logits,
    o,
    b_seq_len,
    num_kv_splits,
):
    """Launch stage 2 (``_mla_softmax_reducev_kernel``): merge per-split partials into ``o``.

    One program per (batch row, head) of ``o``; the head dimension is taken from
    ``o.shape[2]``. Launched with 4 warps and 2 pipeline stages.
    """
    n_batch = o.shape[0]
    n_heads = o.shape[1]
    ckv_dim = o.shape[2]
    _mla_softmax_reducev_kernel[(n_batch, n_heads)](
        logits, b_seq_len, o,
        logits.stride(0), logits.stride(1), logits.stride(2),
        o.stride(0), o.stride(1),
        NUM_KV_SPLITS=num_kv_splits,
        HEAD_DIM_CKV=ckv_dim,
        num_warps=4,
        num_stages=2,
    )
def mla_decode_triton(
    q_nope,
    q_pe,
    kv_c_cache,
    k_pe_cache,
    o,
    req_to_tokens,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
):
    """Two-stage Triton split-KV MLA decode.

    Stage 1 writes, for every KV split, a normalized partial output plus its
    log-sum-exp into the scratch ``attn_logits`` (third dim must equal
    ``num_kv_splits``; run_flash_mla_triton allocates it as
    (rows, heads, num_kv_splits, head_dim_ckv + 1)). Stage 2 merges the splits
    into ``o`` in place.

    Raises:
        AssertionError: if ``attn_logits.shape[2] != num_kv_splits``.
    """
    assert num_kv_splits == attn_logits.shape[2]
    # Stage 1: per-split partial attention.
    _mla_attn(
        q_nope=q_nope,
        q_pe=q_pe,
        kv_c_cache=kv_c_cache,
        k_pe_cache=k_pe_cache,
        attn_logits=attn_logits,
        req_to_tokens=req_to_tokens,
        b_seq_len=b_seq_len,
        num_kv_splits=num_kv_splits,
        sm_scale=sm_scale,
        page_size=page_size,
    )
    # Stage 2: LSE-weighted merge of the splits into `o`.
    _mla_softmax_reducev(
        logits=attn_logits,
        o=o,
        b_seq_len=b_seq_len,
        num_kv_splits=num_kv_splits,
    )
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    """Run the Triton split-KV MLA decode reference (32 KV splits) and time it.

    Side effect: padding slots of ``blocked_k`` (at or past ``cache_seqlens[i]``
    in row ``i``) are overwritten in place with NaN. ``causal`` is not applied
    by the Triton kernels.

    Returns:
        ``(out, None, t)``: ``out`` of shape (b, s_q, h_q, dv); no LSE is returned;
        ``t`` is the ``triton.testing.do_bench`` timing.

    Raises:
        AssertionError: if ``s_q != 1`` or ``h_kv != 1``, or ``d <= dv``.
    """
    # The Triton kernels index b_seq_len / req_to_tokens by the flattened (b * s_q)
    # query row and have no KV-head dimension, so other shapes would read out of
    # bounds or the wrong cache rows.
    assert s_q == 1, "flash_mla_triton only supports s_q == 1"
    assert h_kv == 1, "flash_mla_triton only supports h_kv == 1"
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    assert d > dv, "mla with rope dim should be larger than no rope dim"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    def flash_mla_triton():
        num_kv_splits = 32
        o = torch.empty([b * s_q, h_q, dv])
        # Per-split partial outputs plus one trailing slot for each split's LSE.
        # Kept in float32: with the default dtype (bf16 in the shipped configs)
        # the LSE loses precision before the cross-split merge.
        attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1], dtype=torch.float32)
        mla_decode_triton(
            q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv),
            blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d - dv),
            o, block_table, cache_seqlens, attn_logits, num_kv_splits,
            1 / math.sqrt(d), block_size,
        )
        return o.view([b, s_q, h_q, dv])

    out_flash = flash_mla_triton()
    t = triton.testing.do_bench(flash_mla_triton)
    return out_flash, None, t
# Benchmark name -> runner. All runners share one signature and return
# (out, lse or None, do_bench timing).
FUNC_TABLE = dict(
    torch=run_torch_mla,
    flash_mla=run_flash_mla,
    flash_infer=run_flash_infer,
    flash_mla_triton=run_flash_mla_triton,
)
def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    """Benchmark two FUNC_TABLE runners on identical random inputs and check they agree.

    Sets the global torch default dtype (``dtype``) and device (cuda:0) and reseeds
    ``torch`` and ``random``, then builds ``q`` and a paged KV cache whose
    block_table maps every batch row to its own contiguous run of pages.

    Args:
        baseline, target: keys of FUNC_TABLE.
        b, s_q, h_q, h_kv: batch size, query tokens per row, query heads, KV heads.
        cache_seqlens: per-row KV lengths (int tensor with one entry per row).
        d, dv: QK head dim and V head dim.
        causal: forwarded to both runners.
        dtype: default dtype for the generated inputs.

    Returns:
        ``(baseline_gbps, target_gbps)``: effective bandwidth of each runner in GB/s.

    Raises:
        AssertionError: if a name is not in FUNC_TABLE, or if the outputs (or,
        when both runners return a comparable LSE, the LSEs) are not close.
    """
    print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)
    assert baseline in FUNC_TABLE
    assert target in FUNC_TABLE
    baseline_func = FUNC_TABLE[baseline]
    target_func = FUNC_TABLE[target]
    total_seqlens = cache_seqlens.sum().item()
    max_seqlen = cache_seqlens.max().item()
    # Pad each row to a multiple of 256 tokens (a whole number of 64-token pages).
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    # `msg` prefixes the label while keeping torch's detailed mismatch report (the
    # former trailing `, "out"` only built a discarded tuple).
    torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2, msg=lambda m: f"out: {m}")
    # flash_infer has a different lse return value; flash_mla_triton doesn't return lse.
    no_comparable_lse = {"flash_infer", "flash_mla_triton"}
    if target not in no_comparable_lse and baseline not in no_comparable_lse:
        torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2, msg=lambda m: f"lse: {m}")
    # QK^T over d plus PV over dv, 2 FLOPs per multiply-add.
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    # KV cache read + q read + output written.
    num_bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
    print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {num_bytes / 10 ** 6 / perf_a:.0f} GB/s")
    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {num_bytes / 10 ** 6 / perf_b:.0f} GB/s")
    return num_bytes / 10 ** 6 / perf_a, num_bytes / 10 ** 6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    """Benchmark a single FUNC_TABLE runner and report its throughput.

    Sets the global torch default dtype (``dtype``) and device (cuda:0), reseeds
    ``torch`` and ``random``, and builds ``q`` plus a paged KV cache whose
    block_table maps every batch row to its own contiguous run of pages.

    Args:
        target: key of FUNC_TABLE.
        b, s_q, h_q, h_kv: batch size, query tokens per row, query heads, KV heads.
        cache_seqlens: per-row KV lengths (int tensor with one entry per row).
        d, dv: QK head dim and V head dim.
        causal: forwarded to the runner.
        dtype: default dtype for the generated inputs.

    Returns:
        Effective bandwidth of the runner in GB/s.

    Raises:
        AssertionError: if ``target`` is not in FUNC_TABLE.
    """
    print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
    torch.set_default_dtype(dtype)
    device = torch.device("cuda:0")
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)
    assert target in FUNC_TABLE
    target_func = FUNC_TABLE[target]
    total_seqlens = cache_seqlens.sum().item()
    max_seqlen = cache_seqlens.max().item()
    # Pad each row to a multiple of 256 tokens (a whole number of 64-token pages).
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    # QK^T over d plus PV over dv, 2 FLOPs per multiply-add.
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    # KV cache read + q read + output written.
    num_bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {num_bytes / 10 ** 6 / perf_b:.0f} GB/s")
    return num_bytes / 10 ** 6 / perf_b
# Runner names benchmarked by --all, in order.
available_targets = ["torch", "flash_mla", "flash_infer", "flash_mla_triton"]

# Decode shapes: batch 128, one query token per row, 128 query heads over one KV
# head, d = 512 + 64, dv = 512. Row i's KV length is the base length plus 2 * i.
shape_configs = [
    {
        "b": batch,
        "s_q": 1,
        "cache_seqlens": torch.tensor(list(range(seqlen, seqlen + 2 * batch, 2)), dtype=torch.int32, device="cuda"),
        "h_q": head,
        "h_kv": 1,
        "d": 512 + 64,
        "dv": 512,
        "causal": True,
        "dtype": torch.bfloat16,
    }
    for batch in [128]
    for seqlen in [1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4]
    for head in [128]
]
def get_args():
    """Parse the benchmark's command-line options.

    Returns:
        argparse.Namespace with ``baseline`` / ``target`` (implementation names,
        defaulting to "torch" / "flash_mla") and the boolean mode flags
        ``all``, ``one`` and ``compare``.
    """
    parser = argparse.ArgumentParser()
    for flag, default_name in (("--baseline", "torch"), ("--target", "flash_mla")):
        parser.add_argument(flag, type=str, default=default_name)
    for flag in ("--all", "--one", "--compare"):
        parser.add_argument(flag, action="store_true")
    return parser.parse_args()
if __name__ == "__main__":
    args = get_args()
    # The CSV is named after the benchmark mode: "all", "<baseline>_vs_<target>", or the target.
    if args.all:
        benchmark_type = "all"
    elif args.compare:
        benchmark_type = f"{args.baseline}_vs_{args.target}"
    else:
        benchmark_type = args.target
    with open(f"{benchmark_type}_perf.csv", "w") as fout:
        fout.write("name,batch,seqlen,head,bw\n")

        def write_row(target, shape, perf):
            """Append one CSV row: implementation, batch, mean KV length, heads, bandwidth (GB/s)."""
            fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')

        for shape in shape_configs:
            # Positional arguments that follow the implementation name(s) in compare_a / compare_ab.
            shape_args = (shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
            if args.all:
                for target in available_targets:
                    write_row(target, shape, compare_a(target, *shape_args))
            elif args.compare:
                baseline_bw, target_bw = compare_ab(args.baseline, args.target, *shape_args)
                write_row(args.baseline, shape, baseline_bw)
                write_row(args.target, shape, target_bw)
            elif args.one:
                write_row(args.target, shape, compare_a(args.target, *shape_args))
================================================
FILE: benchmark/visualize.py
================================================
import argparse
import matplotlib.pyplot as plt
import pandas as pd
def parse_args():
    """Parse command-line options for the plotting script.

    Returns:
        argparse.Namespace whose ``file`` attribute is the CSV path to plot.
    """
    parser = argparse.ArgumentParser(description='Visualize benchmark results')
    parser.add_argument(
        '--file',
        type=str,
        default='all_perf.csv',
        help='Path to the CSV file with benchmark results (default: all_perf.csv)',
    )
    namespace = parser.parse_args()
    return namespace
import os

# Plot achieved bandwidth against KV sequence length, one line per implementation,
# reading the CSV written by bench_flash_mla.py (columns: name, batch, seqlen, head, bw).
args = parse_args()
file_path = args.file
df = pd.read_csv(file_path)
names = df['name'].unique()
for name in names:
    # Sort so each line is drawn left-to-right even if the CSV rows are not in seqlen order.
    subset = df[df['name'] == name].sort_values('seqlen')
    plt.plot(subset['seqlen'], subset['bw'], label=name)
plt.title('bandwidth')
plt.xlabel('seqlen')
plt.ylabel('bw (GB/s)')
plt.legend()
# Name the image after the CSV's base name, e.g. "results/all_perf.csv" -> "all_perf_bandwidth_vs_seqlen.png".
# os.path is used so a leading "./" or dots in directory names do not truncate the stem.
plot_stem = os.path.splitext(os.path.basename(file_path))[0]
plt.savefig(f'{plot_stem}_bandwidth_vs_seqlen.png')
================================================
FILE: csrc/api/api.cpp
================================================
#include <pybind11/pybind11.h>
#include "sparse_fwd.h"
#include "sparse_decode.h"
#include "dense_decode.h"
#include "dense_fwd.h"
// Python extension module: registers FlashMLA's C++ entry points under the names below.
// The module name itself comes from TORCH_EXTENSION_NAME, supplied at build time.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA";
// Decoding entry points (sparse_decode.h / dense_decode.h).
m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("dense_decode_fwd", &dense_attn_decode_interface);
// Prefill entry points; presumably declared via sparse_fwd.h / dense_fwd.h (not visible here).
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun);
m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun);
}
================================================
FILE: csrc/api/common.h
================================================
#pragma once
#include <span>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <kerutils/supplemental/torch_tensors.h>
#include <cutlass/bfloat16.h>
// log2(e). Multiplying a natural-log softmax scale by this gives the base-2 scale
// (see `sm_scale * LOG_2_E` in sparse_attn_decode_interface).
static constexpr float LOG_2_E = 1.44269504f;
// Explicit specialization so that `tensor.data_ptr<cutlass::bfloat16_t>()` compiles.
// It returns the untyped data_ptr() reinterpreted as cutlass::bfloat16_t*.
// No dtype check is performed here: callers must ensure the tensor really holds bfloat16 data.
template<>
inline cutlass::bfloat16_t* at::TensorBase::data_ptr<cutlass::bfloat16_t>() const {
return reinterpret_cast<cutlass::bfloat16_t*>(this->data_ptr());
}
// A struct that holds the architecture information of the current GPU.
struct Arch {
int major;
int minor;
int num_sms;
cudaDeviceProp* device_prop;
Arch() {
device_prop = at::cuda::getCurrentDeviceProperties();
major = device_prop->major;
minor = device_prop->minor;
num_sms = device_prop->multiProcessorCount;
}
bool is_sm90a() const {
return major == 9 && minor == 0;
}
bool is_sm100f() const {
return major == 10;
}
};
// Convert an int64_t stride (as returned by at::Tensor::stride()) to int.
// Fails via TORCH_CHECK when the value lies outside [INT_MIN, INT_MAX], instead of truncating.
inline int int64_stride_to_int(int64_t orig_stride) {
    TORCH_CHECK(
        orig_stride >= static_cast<int64_t>(std::numeric_limits<int>::min()) &&
        orig_stride <= static_cast<int64_t>(std::numeric_limits<int>::max()),
        "[FlashMLA] Stride exceeds int32 limit: ", orig_stride);
    return static_cast<int>(orig_stride);
}
// Calls the callable passed as __VA_ARGS__ with CONSTEXPR_NAME bound, as a compile-time
// constant, to the runtime value NUM_HEADS. Only 128 and 64 are accepted; any other value
// fails TORCH_CHECK.
// Expands to an immediately-invoked lambda and evaluates to whatever the callable returns.
// NOTE(review): the expansion already ends in ';'. A trailing ';' at the use site therefore adds
// an empty statement, which breaks use as the sole body of an un-braced if/else.
#define DISPATCH_NUM_HEADS(NUM_HEADS, CONSTEXPR_NAME, ...) \
[&] () { \
if (NUM_HEADS == 128) { \
static constexpr int CONSTEXPR_NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_HEADS == 64) { \
static constexpr int CONSTEXPR_NAME = 64; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \
} \
} ();
// Calls the callable passed as __VA_ARGS__ with CONSTEXPR_NAME bound, as a compile-time
// constant, to the runtime value HEAD_DIM. Only 576 and 512 are accepted; any other value
// fails TORCH_CHECK.
// Same expansion shape (immediately-invoked lambda ending in ';') as DISPATCH_NUM_HEADS.
#define DISPATCH_HEAD_DIM(HEAD_DIM, CONSTEXPR_NAME, ...) \
[&] () { \
if (HEAD_DIM == 576) { \
static constexpr int CONSTEXPR_NAME = 576; \
return __VA_ARGS__(); \
} else if (HEAD_DIM == 512) { \
static constexpr int CONSTEXPR_NAME = 512; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported head_dim_qk: ", HEAD_DIM); \
} \
} ();
// Calls the callable passed as __VA_ARGS__ with CONSTEXPR_NAME bound to the compile-time bool
// matching the runtime FLAG. Both branches are always instantiated.
// Same expansion shape (immediately-invoked lambda ending in ';') as DISPATCH_NUM_HEADS.
#define DISPATCH_BOOLEAN_FLAG(FLAG, CONSTEXPR_NAME, ...) \
[&] () { \
if (FLAG) { \
static constexpr bool CONSTEXPR_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONSTEXPR_NAME = false; \
return __VA_ARGS__(); \
} \
} ();
// Calls the callable passed as __VA_ARGS__ with CONSTEXPR_NAME bound to the compile-time
// ModelType matching the runtime MODEL_TYPE. Only ModelType::V32 and ModelType::MODEL1 are
// accepted; anything else fails TORCH_CHECK, which prints the value as an int.
// Same expansion shape (immediately-invoked lambda ending in ';') as DISPATCH_NUM_HEADS.
#define DISPATCH_MODEL_TYPE(MODEL_TYPE, CONSTEXPR_NAME, ...) \
[&] () { \
if (MODEL_TYPE == ModelType::V32) { \
static constexpr ModelType CONSTEXPR_NAME = ModelType::V32; \
return __VA_ARGS__(); \
} else if (MODEL_TYPE == ModelType::MODEL1) { \
static constexpr ModelType CONSTEXPR_NAME = ModelType::MODEL1; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported model type: ", (int)MODEL_TYPE); \
} \
} ();
// The following code is adapted from https://ykiko.me/en/articles/680412313/, which converts enum values to string names.
// get_static_enum_name<value>() returns, at compile time, the enumerator name of `value` with
// leading qualification stripped. For an enum at global scope, e.g. DecodeFeatures::HEAD_64
// presumably yields "HEAD_64".
// It works by slicing the compiler's pretty-printed signature of this instantiation:
//   - GCC/Clang (__PRETTY_FUNCTION__): keep the text after the first '=' (skipping 2 chars),
//     drop the final character, then strip everything through the FIRST "::".
//   - MSVC (__FUNCSIG__): keep the text between the first '<' and the last ">(",
//     then strip everything through the LAST "::".
// If `value` is not a named enumerator, compilers print a cast-like expression instead of a name
// (the exact form is compiler-specific). get_enum_max() relies on that text containing ')'.
// NOTE(review): the output depends entirely on compiler pretty-printing. On GCC/Clang an enum
// nested in a namespace/class keeps part of its qualification, since only the first "::" is stripped.
template<auto value>
constexpr auto get_static_enum_name(){
std::string_view name;
#if __GNUC__ || __clang__
name = __PRETTY_FUNCTION__;
std::size_t start = name.find('=') + 2;
std::size_t end = name.size() - 1;
name = std::string_view{ name.data() + start, end - start };
start = name.find("::");
#elif _MSC_VER
name = __FUNCSIG__;
std::size_t start = name.find('<') + 1;
std::size_t end = name.rfind(">(");
name = std::string_view{ name.data() + start, end - start };
start = name.rfind("::");
#endif
// No "::" found: return the sliced text as-is; otherwise return what follows the "::".
return start == std::string_view::npos ? name : std::string_view {
name.data() + start + 2, name.size() - start - 2
};
}
// Counts the named enumerators of T. It probes static_cast<T>(0), static_cast<T>(1), ... at
// compile time and returns the first N whose get_static_enum_name() text contains ')', i.e. the
// first value the compiler prints as a cast rather than as an enumerator name.
// Assumes T's enumerators are exactly 0 .. N-1 with no gaps. Each probe is a separate
// template instantiation, so very large enums can hit the compiler's instantiation-depth limit.
template<typename T, std::size_t N = 0>
static constexpr std::size_t get_enum_max(){
    constexpr auto probed_name = get_static_enum_name<static_cast<T>(N)>();
    if constexpr (probed_name.find(")") != std::string_view::npos) {
        return N;
    } else {
        return get_enum_max<T, N + 1>();
    }
}
// Runtime counterpart of get_static_enum_name. It returns the name of `value` by indexing a
// compile-time table holding the names of enumerators 0 .. get_enum_max<T>() - 1.
// Values outside that range (including negative ones, which wrap to huge size_t) return
// "<unknown>" instead of reading past the end of the table.
template<typename T> requires std::is_enum_v<T>
static constexpr std::string get_dynamic_enum_name(T value){
    constexpr std::size_t num = get_enum_max<T>();
    constexpr auto names = []<std::size_t... Is>(std::index_sequence<Is...>){
        return std::array<std::string_view, num>{
            get_static_enum_name<static_cast<T>(Is)>()...
        };
    }(std::make_index_sequence<num>{});
    const auto index = static_cast<std::size_t>(value);
    if (index >= num) {
        return "<unknown>";
    }
    return std::string(names[index]);
}
// A shortcut macro to declare supported features in an implementation class.
// It expands to a `protected:` section (which remains in effect after the macro) containing:
//   - a static `features` array holding the given FeatureT values, and
//   - an override of ImplBase::get_supported_features() returning that array as a std::span.
// Use it only inside a class derived from ImplBase, since it relies on the FeatureT alias and on
// the virtual it overrides. Put a `public:` label after it if public members follow.
#define DECLARE_SUPPORTED_FEATURES(...) \
protected: \
static constexpr FeatureT features[] = { __VA_ARGS__ }; \
constexpr inline std::span<const FeatureT> get_supported_features() const override { \
return features; \
}
/*
ImplBase - The base class for every implementation.
Every implementation should inherit from this class and implement the pure virtual functions, including:
- `run_`: The function that runs the implementation.
- `get_supported_features`: The function that returns the supported features of the implementation. You may use `DECLARE_SUPPORTED_FEATURES` to declare the supported features in a concise way.
The dispatcher will invoke `ImplBase::run()`. It checks that every required feature is supported by the implementation (aborting via TORCH_CHECK with a diagnostic report otherwise), and then calls `run_`.
*/
template<
    typename RunArgT_,
    typename FeatureT_
>
class ImplBase {
protected:
    using RunArgT = RunArgT_;
    using FeatureT = FeatureT_;

    // Launches the implementation. Only called by run() after the feature check has passed.
    virtual inline void run_(const RunArgT &params, const std::vector<FeatureT> &required_features) = 0;

    // Returns every feature this implementation supports.
    constexpr virtual inline std::span<const FeatureT> get_supported_features() const = 0;

    virtual ~ImplBase() = default;

    // True if `feature` appears in get_supported_features().
    inline bool is_feature_supported(const FeatureT &feature) const {
        for (const auto &supported_feature : get_supported_features()) {
            if (supported_feature == feature) {
                return true;
            }
        }
        return false;
    }

    // Prints one feature to stderr as its integer value followed by its enumerator name.
    static inline void print_feature(const FeatureT &feature) {
        fprintf(stderr, " - %3d: %s\n", static_cast<int>(feature), get_dynamic_enum_name(feature).c_str());
    }

public:
    // True if every entry of `required_features` is supported by this implementation.
    inline bool check_if_all_features_are_supported(const std::vector<FeatureT> &required_features) {
        for (const auto &required_feature : required_features) {
            if (!is_feature_supported(required_feature)) {
                return false;
            }
        }
        return true;
    }

    // Returns silently if all required features are supported. Otherwise prints the required,
    // supported and missing features plus the current GPU to stderr, then fails via TORCH_CHECK.
    inline void check_if_all_features_are_supported_and_abort(const std::vector<FeatureT> &required_features) {
        if (check_if_all_features_are_supported(required_features)) {
            return;
        }
        fprintf(stderr, "[FlashMLA] Error: The chosen implementation does not support all required features.\n");
        fprintf(stderr, "Required features:\n");
        for (const auto &f : required_features) {
            print_feature(f);
        }
        fprintf(stderr, "\n");
        fprintf(stderr, "Supported features:\n");
        for (const auto &supported_feature : get_supported_features()) {
            print_feature(supported_feature);
        }
        fprintf(stderr, "\n");
        fprintf(stderr, "Features that are required but not supported:\n");
        for (const auto &required_feature : required_features) {
            if (!is_feature_supported(required_feature)) {
                print_feature(required_feature);
            }
        }
        fprintf(stderr, "\n");
        Arch cur_gpu_arch = Arch();
        fprintf(stderr, "Current GPU: %s, SM %d.%d with %d SMs\n", cur_gpu_arch.device_prop->name, cur_gpu_arch.major, cur_gpu_arch.minor, cur_gpu_arch.num_sms);
        fprintf(stderr, "This means that the dispatcher has chosen an implementation that does not support all required features. Maybe there is a bug in the dispatcher, or you have requested an invalid combination of features.\n");
        TORCH_CHECK(false, "The chosen implementation does not support all required features. See message above for details.");
    }

    // Entry point used by dispatchers: verify features, then run.
    inline void run(const RunArgT &params, const std::vector<FeatureT> &required_features) {
        check_if_all_features_are_supported_and_abort(required_features);
        run_(params, required_features);
    }
};
================================================
FILE: csrc/api/dense_decode.h
================================================
#pragma once
#include <cutlass/half.h>
#include <cutlass/fast_math.h>
#include "common.h"
#include "params.h"
#include "sm90/decode/dense/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
// Dense (bf16/fp16 KV cache) MLA decoding entry point. Only SM90a is supported.
//
// The query heads that share one KV head (num_heads_q / num_heads_k of them) are folded into
// the query-sequence dimension: q_seq_per_hk = seqlen_q * heads-per-KV-head. The SM90 split-KV
// kernel then writes per-split float partial results into lse_accum / out_accum, and the combine
// kernel merges them into `out` / `lse`.
//
// If tile_scheduler_metadata is omitted, both it and num_splits are computed here. If it is
// supplied, num_splits must be supplied too, and both are validated and reused. Either way
// both are returned.
//
// Returns {out: [b, s_q, h_q, head_size_v] in q's dtype, lse: [b, h_q, s_q] float32,
//          tile_scheduler_metadata, num_splits}.
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_interface(
    at::Tensor &q,                                       // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                            // num_blocks x page_block_size x num_heads_k x head_size, same dtype as q
    const int head_size_v,
    const at::Tensor &seqlens_k,                         // batch_size
    const at::Tensor &block_table,                       // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    std::optional<at::Tensor> &tile_scheduler_metadata,  // num_sm_parts x (DecodingSchedMetaSize/4)
    std::optional<at::Tensor> &num_splits                // batch_size + 1
) {
    // Check arch
    Arch arch = Arch();
    if (!arch.is_sm90a()) {
        TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture");
    }

    // Check data types
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");

    // Check device
    KU_CHECK_DEVICE(q);
    KU_CHECK_DEVICE(kcache);
    KU_CHECK_DEVICE(seqlens_k);
    KU_CHECK_DEVICE(block_table);
    KU_CHECK_DEVICE(tile_scheduler_metadata);
    KU_CHECK_DEVICE(num_splits);

    // When precomputed scheduling metadata is passed in, the code below dereferences num_splits
    // unconditionally, so it must be present as well.
    TORCH_CHECK(!tile_scheduler_metadata.has_value() || num_splits.has_value(),
                "num_splits must be provided when tile_scheduler_metadata is provided");

    // Check rank before reading individual sizes (q.sizes()[i] below is not range-checked)
    KU_CHECK_NDIM(q, 4);
    KU_CHECK_NDIM(kcache, 4);
    KU_CHECK_NDIM(block_table, 2);

    // Check layout
    TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
    KU_CHECK_CONTIGUOUS(seqlens_k);
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
    KU_CHECK_CONTIGUOUS(num_splits);

    const auto sizes = q.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_q = sizes[2];
    const int head_size_k = sizes[3];
    TORCH_CHECK(head_size_k == 576 || head_size_k == 512, "Only head_size_k == 576 or 512 is supported");
    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // A single query token has no later query positions to mask, so causal masking is
    // dropped (presumably equivalent to non-causal in that case; confirm against the kernel).
    if (seqlen_q_ori == 1) { is_causal = false; }

    // Fold the query heads that share a KV head into the sequence dimension:
    // [b, s_q, h_k * per_hk, d] -> [b, s_q * per_hk, h_k, d].
    const int num_q_heads_per_hk = num_heads_q / num_heads_k;
    const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
    const int num_heads = num_heads_k;
    q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
        .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});

    // SMs per KV head, divided by the number of 64-row tiles of folded query rows (at least 1).
    int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 64), 1);

    KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
    KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    KU_CHECK_SHAPE(seqlens_k, batch_size);
    KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int));
    KU_CHECK_SHAPE(num_splits, batch_size+1);

    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts);
    at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
    KU_CHECK_CONTIGUOUS(out);
    KU_CHECK_CONTIGUOUS(lse);

    if (!tile_scheduler_metadata.has_value()) {
        // No precomputed schedule: allocate and fill it (and num_splits) on the current stream.
        tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
        num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
        KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
        KU_CHECK_CONTIGUOUS(num_splits);
        GetDecodeSchedMetaParams get_sched_meta_params = {
            batch_size, seqlen_q_ori,
            64,
            5,
            -1, -1,
            nullptr, nullptr,
            seqlens_k.data_ptr<int>(),
            (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
            num_splits->data_ptr<int>(),
            num_sm_parts,
            at::cuda::getCurrentCUDAStream().stream()
        };
        smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
    } else {
        // Caller-provided schedule: validate it against the shapes derived above.
        KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
        KU_CHECK_DTYPE(num_splits, torch::kInt32);
        KU_CHECK_DEVICE(tile_scheduler_metadata);
        KU_CHECK_DEVICE(num_splits);
        KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
        KU_CHECK_CONTIGUOUS(num_splits);
        KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
        KU_CHECK_SHAPE(num_splits, batch_size+1);
    }

    // Set the sizes
    DenseAttnDecodeParams params;
    params.b = batch_size;
    params.s_q = seqlen_q_ori;
    params.q_seq_per_hk = q_seq_per_hk;
    params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
    params.h_q = num_heads_q;
    params.h_k = num_heads_k;
    params.num_blocks = num_blocks;
    params.q_head_per_hk = num_q_heads_per_hk;
    params.is_causal = is_causal;
    params.d = head_size_k;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);

    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = lse.data_ptr<float>();
    // All stride are in elements, not bytes.
    params.q_batch_stride = q.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_row_stride = q.stride(1);
    params.k_row_stride = kcache.stride(1);
    params.o_row_stride = out.stride(2);
    params.q_head_stride = q.stride(2);
    params.k_head_stride = kcache.stride(2);
    params.o_head_stride = out.stride(1);
    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;
    params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
    params.num_sm_parts = num_sm_parts;
    params.num_splits_ptr = num_splits->data_ptr<int>();

    // Float accumulators for the per-split partial results (batch_size + num_sm_parts splits).
    const int total_num_splits = batch_size + params.num_sm_parts;
    at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
    KU_CHECK_CONTIGUOUS(lse_accum);
    KU_CHECK_CONTIGUOUS(out_accum);
    params.total_num_splits = total_num_splits;
    params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
    params.oaccum_ptr = out_accum.data_ptr<float>();

    params.stream = at::cuda::getCurrentCUDAStream().stream();
    if (q_dtype == torch::kBFloat16) {
        sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);
    } else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
        TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
        sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params);
#endif
    } else {
        TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
    }

    // Merge the per-split partial results into out / lse.
    CombineParams combine_params = {
        batch_size, seqlen_q_ori,
        num_heads_q, head_size_v,
        params.softmax_lse_ptr,
        params.o_ptr,
        num_heads*q_seq_per_hk, num_heads_q,
        num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
        params.softmax_lseaccum_ptr,
        params.oaccum_ptr,
        num_heads*q_seq_per_hk, num_heads_q,
        num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
        params.tile_scheduler_metadata_ptr,
        params.num_splits_ptr,
        params.num_sm_parts,
        nullptr,
        at::cuda::getCurrentCUDAStream().stream()
    };
    if (q_dtype == torch::kBFloat16) {
        smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
    } else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
        smxx::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
#endif
    } else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Unfold the heads back out of the sequence dimension.
    out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
        .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
    lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
        .reshape({batch_size, num_heads_q, seqlen_q_ori});

    return {out, lse, tile_scheduler_metadata, num_splits};
}
================================================
FILE: csrc/api/dense_fwd.h
================================================
#pragma once
#include "common.h"
#include "sm100/prefill/dense/interface.h"
================================================
FILE: csrc/api/sparse_decode.h
================================================
#pragma once
#include "common.h"
#include "params.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm100/decode/head64/kernel.h"
#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
// Feature set of sparse decoding kernels.
// Values are numbered explicitly from 0 with no gaps. get_dynamic_enum_name() builds its name
// table by probing consecutive values starting at 0, so new features must keep the sequence dense.
enum class DecodeFeatures : int {
    HEAD_64 = 0,                // h_q == 64
    HEAD_128 = 1,               // h_q == 128
    HEAD_DIM_576 = 2,           // d_qk == 576
    HEAD_DIM_512 = 3,           // d_qk == 512
    V32_KVCACHE_FORMAT = 4,     // V3.2 KV-cache layout (selected when d_qk == 576)
    MODEL1_KVCACHE_FORMAT = 5,  // MODEL1 KV-cache layout (selected when d_qk == 512)
    ATTN_SINK = 6,              // an attn_sink tensor is supplied
    TOPK_LENGTH = 7,            // a topk_length tensor is supplied
    EXTRA_KVCACHE = 8,          // an extra KV cache is supplied
    EXTRA_TOPK_LENGTH = 9       // an extra_topk_length tensor is supplied
};
// Per-shape scheduling parameters an implementation reports through get_meta().
// sparse_attn_decode_interface uses them to build the decoding schedule and split-KV buffers.
struct DecodeImplMeta {
int num_sm_parts;              // rows of tile_scheduler_metadata; split-KV buffers hold b + num_sm_parts splits
int fixed_overhead_num_blocks; // forwarded to GetDecodeSchedMetaParams (exact cost-model meaning not visible here)
int block_size_topk;           // forwarded to GetDecodeSchedMetaParams as the top-k block size
};
// Common base for sparse-decode implementations: an ImplBase over SparseAttnDecodeParams and
// DecodeFeatures that additionally reports per-shape scheduling metadata.
class DecodeImplBase : public ImplBase<
SparseAttnDecodeParams,
DecodeFeatures
> {
public:
// Scheduling metadata for h_q query heads and s_q query tokens per sequence.
virtual DecodeImplMeta get_meta(int h_q, int s_q) = 0;
};
class Decode_Sm90_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_64,
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q / (h_q/64), 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams ¶ms, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() {
sm90::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE, NUM_HEADS>(params);
});
});
}
};
class Decode_Sm100_Head64_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_64,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q, 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams ¶ms, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(params);
});
}
};
// An implementation that calls the head64 kernel twice to process head128
// Necessary for running V3.2 shape (i.e. h = 128, d_qk = 576) on SM100f
class Decode_Sm100_Head64x2_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q, 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams ¶ms, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
for (int start_head_idx = 0; start_head_idx < 128; start_head_idx += 64) {
SparseAttnDecodeParams cur_params = params;
cur_params.q += start_head_idx * params.stride_q_h_q;
if (cur_params.attn_sink) {
cur_params.attn_sink += start_head_idx;
}
cur_params.lse += start_head_idx;
cur_params.out += start_head_idx * params.stride_o_h_q;
cur_params.lse_accum += start_head_idx;
cur_params.o_accum += start_head_idx * params.stride_o_accum_h_q;
cur_params.h_q = 64;
sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(cur_params);
}
});
}
};
class Decode_Sm100_Head128_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q / 2, 1),
3,
64
};
}
protected:
void run_(const SparseAttnDecodeParams ¶ms, const std::vector<FeatureT> &required_features) override {
sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::DecodeWithSplitKV, 512>(params);
}
};
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
sparse_attn_decode_interface(
const at::Tensor &q, // [b, s_q, h_q, d_qk]
const at::Tensor &kv, // [num_blocks, page_block_size, h_k, d_qk]
const at::Tensor &indices, // [b, s_q, topk]
const std::optional<at::Tensor> &topk_length, // [b, s_q]
const std::optional<at::Tensor> &attn_sink, // [h_q]
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits, // batch_size + 1
const std::optional<at::Tensor> &extra_kv,
const std::optional<at::Tensor> &extra_indices,
const std::optional<at::Tensor> &extra_topk_length,
int d_v,
float sm_scale
) {
using bf16 = cutlass::bfloat16_t;
// Check the architecture
Arch arch = Arch();
KU_CHECK_NDIM(q, 4);
KU_CHECK_NDIM(kv, 4);
KU_CHECK_NDIM(indices, 3);
int b = q.size(0);
int s_q = q.size(1);
int h_q = q.size(2);
int d_qk = q.size(3);
int num_blocks = kv.size(0);
int page_block_size = kv.size(1);
int h_kv = kv.size(2);
int topk = indices.size(2);
bool have_topk_length = topk_length.has_value();
bool have_extra_kcache = extra_kv.has_value();
bool have_extra_topk_length = extra_topk_length.has_value();
bool have_attn_sink = attn_sink.has_value();
int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0;
if (have_extra_kcache) {
extra_num_blocks = extra_kv->size(0);
extra_page_block_size = extra_kv->size(1);
}
if (extra_indices.has_value()) {
extra_topk = extra_indices->size(-1);
}
// metadata sanity check
TORCH_CHECK(b > 0);
TORCH_CHECK(s_q > 0);
TORCH_CHECK(h_q > 0);
TORCH_CHECK(h_kv == 1, "Currently only MQA (i.e. h_kv == 1) is supported for sparse decoding");
TORCH_CHECK(d_qk == 576 || d_qk == 512, "Only head_size_k == 576 or 512 is supported for sparse decoding");
TORCH_CHECK(d_v == 512, "Only head_size_v == 512 is supported for sparse decoding");
TORCH_CHECK(topk > 0);
if (have_extra_kcache) {
TORCH_CHECK(extra_indices.has_value(), "extra_indices_in_kvcache must be provided when extra_kcache is provided for sparse attention");
} else {
TORCH_CHECK(!extra_indices.has_value(), "extra_indices_in_kvcache must not be provided when extra_k_cache is not provided");
TORCH_CHECK(!extra_topk_length.has_value(), "extra_topk_length must not be provided when extra_k_cache is not provided");
}
// Check device
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kv);
KU_CHECK_DEVICE(indices);
KU_CHECK_DEVICE(topk_length);
KU_CHECK_DEVICE(attn_sink);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DEVICE(extra_kv);
KU_CHECK_DEVICE(extra_indices);
KU_CHECK_DEVICE(extra_topk_length);
// Check data type
KU_CHECK_DTYPE(q, torch::kBFloat16);
TORCH_CHECK(kv.dtype() == torch::kFloat8_e4m3fn || kv.dtype() == torch::kInt8 || kv.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn, int8 or uint8");
if (extra_kv.has_value()) {
TORCH_CHECK(extra_kv->dtype() == torch::kFloat8_e4m3fn || extra_kv->dtype() == torch::kInt8 || extra_kv->dtype() == torch::kUInt8, "extra k cache must have dtype fp8_e4m3fn, int8 or uint8");
}
KU_CHECK_DTYPE(indices, torch::kInt32);
KU_CHECK_DTYPE(topk_length, torch::kInt32);
KU_CHECK_DTYPE(attn_sink, torch::kFloat32);
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DTYPE(extra_indices, torch::kInt32);
KU_CHECK_DTYPE(extra_topk_length, torch::kInt32);
// Check layout
KU_CHECK_LAST_DIM_CONTIGUOUS(q);
KU_CHECK_LAST_DIM_CONTIGUOUS(kv);
KU_CHECK_LAST_DIM_CONTIGUOUS(indices);
KU_CHECK_CONTIGUOUS(topk_length);
KU_CHECK_CONTIGUOUS(attn_sink);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_LAST_DIM_CONTIGUOUS(extra_kv);
KU_CHECK_LAST_DIM_CONTIGUOUS(extra_indices);
KU_CHECK_CONTIGUOUS(extra_topk_length);
// Check shape
KU_CHECK_SHAPE(q, b, s_q, h_q, d_qk);
{
int bytes_per_token;
if (d_qk == 576 && d_v == 512) {
// V3.2 style
bytes_per_token = 512 + 64*2 + (512/128)*4;
} else if (d_qk == 512 && d_v == 512) {
// MODEL1 style
bytes_per_token = 448 + 64*2 + (448/64)*1 + 1;
} else {
TORCH_CHECK(false, "Unsupported head sizes for is_fp8_kvcache == True");
}
KU_CHECK_SHAPE(kv, num_blocks, page_block_size, h_kv, bytes_per_token);
KU_CHECK_SHAPE(extra_kv, extra_num_blocks, extra_page_block_size, h_kv, bytes_per_token);
TORCH_CHECK(kv.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for kv cache");
if (extra_kv.has_value()) {
TORCH_CHECK(extra_kv->stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for extra kv cache");
}
}
KU_CHECK_SHAPE(indices, b, s_q, topk);
KU_CHECK_SHAPE(topk_length, b);
KU_CHECK_SHAPE(attn_sink, h_q);
KU_CHECK_SHAPE(extra_indices, b, s_q, extra_topk);
KU_CHECK_SHAPE(extra_topk_length, b);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({b, s_q, h_q, d_v}, opts);
at::Tensor lse = torch::empty({b, s_q, h_q}, opts.dtype(at::kFloat));
ModelType model_type;
if (d_qk == 576) {
model_type = ModelType::V32;
} else if (d_qk == 512) {
model_type = ModelType::MODEL1;
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
std::vector<DecodeFeatures> features;
if (h_q == 64) {
features.push_back(DecodeFeatures::HEAD_64);
} else if (h_q == 128) {
features.push_back(DecodeFeatures::HEAD_128);
} else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q);
}
if (d_qk == 576) {
features.push_back(DecodeFeatures::HEAD_DIM_576);
} else if (d_qk == 512) {
features.push_back(DecodeFeatures::HEAD_DIM_512);
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
if (model_type == ModelType::V32) {
features.push_back(DecodeFeatures::V32_KVCACHE_FORMAT);
} else if (model_type == ModelType::MODEL1) {
features.push_back(DecodeFeatures::MODEL1_KVCACHE_FORMAT);
} else {
TORCH_CHECK(false, "Unsupported model type: ", (int)model_type);
}
if (have_attn_sink) {
features.push_back(DecodeFeatures::ATTN_SINK);
}
if (have_topk_length) {
features.push_back(DecodeFeatures::TOPK_LENGTH);
}
if (have_extra_kcache) {
features.push_back(DecodeFeatures::EXTRA_KVCACHE);
}
if (have_extra_topk_length) {
features.push_back(DecodeFeatures::EXTRA_TOPK_LENGTH);
}
DecodeImplBase* impl;
if (arch.is_sm100f()) {
if (h_q == 64) {
impl = new Decode_Sm100_Head64_Impl();
} else if (h_q == 128) {
if (d_qk == 576) {
impl = new Decode_Sm100_Head64x2_Impl();
} else if (d_qk == 512) {
impl = new Decode_Sm100_Head128_Impl();
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
} else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q);
}
} else if (arch.is_sm90a()) {
impl = new Decode_Sm90_Impl();
} else {
TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd");
}
DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q);
SparseAttnDecodeParams params = {
b, s_q, h_q, h_kv, d_qk, d_v,
sm_scale, sm_scale * LOG_2_E,
num_blocks, page_block_size, topk,
model_type,
(bf16*)q.data_ptr(),
(bf16*)kv.data_ptr(),
(int*)indices.data_ptr(),
ku::get_optional_tensor_ptr<int>(topk_length),
ku::get_optional_tensor_ptr<float>(attn_sink),
(float*)lse.data_ptr(),
(bf16*)out.data_ptr(),
extra_num_blocks, extra_page_block_size, extra_topk,
ku::get_optional_tensor_ptr<bf16>(extra_kv),
ku::get_optional_tensor_ptr<int>(extra_indices),
ku::get_optional_tensor_ptr<int>(extra_topk_length),
int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), int64_stride_to_int(q.stride(2)),
int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),
int64_stride_to_int(lse.stride(0)), int64_stride_to_int(lse.stride(1)),
int64_stride_to_int(out.stride(0)), int64_stride_to_int(out.stride(1)), int64_stride_to_int(out.stride(2)),
have_extra_kcache ? int64_stride_to_int(extra_kv->stride(0)) : 0,
have_extra_kcache ? int64_stride_to_int(extra_kv->stride(1)) : 0,
have_extra_kcache ? int64_stride_to_int(extra_indices->stride(0)) : 0,
have_extra_kcache ? int64_stride_to_int(extra_indices->stride(1)) : 0,
at::cuda::getCurrentCUDAStream().stream()
};
// Get MLA metadata if necessary
at::Tensor o_accum, lse_accum;
if (!tile_scheduler_metadata.has_value()) {
tile_scheduler_metadata = torch::empty({impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({b+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
GetDecodeSchedMetaParams get_sched_meta_params = {
b, s_q,
impl_meta.block_size_topk,
impl_meta.fixed_overhead_num_blocks,
topk,
extra_topk,
ku::get_optional_tensor_ptr<int>(topk_length),
ku::get_optional_tensor_ptr<int>(extra_topk_length),
nullptr,
(DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
num_splits->data_ptr<int>(),
impl_meta.num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
}
// Stick the metadata pointers to `params`
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_SHAPE(tile_scheduler_metadata, impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, b+1);
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_splits_ptr = num_splits->data_ptr<int>();
params.num_sm_parts = impl_meta.num_sm_parts;
// Allocate intermediate buffers for split-KV
const int total_num_splits = b + impl_meta.num_sm_parts;
lse_accum = torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat));
o_accum = torch::empty({total_num_splits, s_q, h_q, d_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(o_accum);
params.lse_accum = lse_accum.data_ptr<float>();
params.o_accum = o_accum.data_ptr<float>();
params.stride_lse_accum_split = int64_stride_to_int(lse_accum.stride(0));
params.stride_lse_accum_s_q = int64_stride_to_int(lse_accum.stride(1));
params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0));
params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1));
params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2));
impl->run(params, features);
CombineParams combine_params = {
b, s_q, h_q, d_v,
params.lse,
params.out,
params.stride_lse_b, params.stride_lse_s_q,
params.stride_o_b, params.stride_o_s_q, params.stride_o_h_q,
params.lse_accum,
params.o_accum,
params.stride_lse_accum_split, params.stride_lse_accum_s_q,
params.stride_o_accum_split, params.stride_o_accum_s_q, params.stride_o_accum_h_q,
params.tile_scheduler_metadata_ptr,
params.num_splits_ptr,
params.num_sm_parts,
ku::get_optional_tensor_ptr<float>(attn_sink),
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_flash_mla_combine_kernel<bf16>(combine_params);
delete impl;
return {out, lse.transpose(1, 2), tile_scheduler_metadata, num_splits};
}
================================================
FILE: csrc/api/sparse_fwd.h
================================================
#pragma once
#include "common.h"
#include "params.h"
#include "sm90/prefill/sparse/phase1.h"
#include "sm100/prefill/sparse/fwd/head128/phase1.h"
#include "sm100/prefill/sparse/fwd/head64/phase1.h"
#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h"
// Capabilities a sparse-prefill forward call may require and an implementation may
// advertise. Enumerator values are spelled out explicitly; they are identical to the
// implicit values produced by the sequential declaration order.
enum class FwdFeatures : int {
    HEAD_64      = 0,  // requested when h_q == 64
    HEAD_128     = 1,  // requested when h_q == 128
    HEAD_DIM_576 = 2,  // requested when d_qk == 576
    HEAD_DIM_512 = 3,  // requested when d_qk == 512
    ATTN_SINK    = 4,  // requested when an attn_sink tensor is supplied
    SINK_LSE     = 5,
    TOPK_LENGTH  = 6   // requested when a topk_length tensor is supplied
};
// Shared base for every sparse-prefill forward implementation: instantiates the generic
// ImplBase with the forward-pass parameter struct and the FwdFeatures enum.
class FwdImplBase
    : public ImplBase<SparseAttnFwdParams, FwdFeatures> {
};
class Fwd_Sm90_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_64,
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams ¶ms, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
});
});
}
};
class Fwd_Sm100_Head64_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_64,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams ¶ms, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class Fwd_Sm100_Head128_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams ¶ms, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams ¶ms, const std::vector<FeatureT> &required_features) override {
sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params);
}
};
/*
 * Host entry point for the sparse-attention prefill forward pass.
 *
 * Shapes / dtypes enforced by the checks below:
 *   q           : [s_q,  h_q,  d_qk]  bf16, last dim contiguous
 *   kv          : [s_kv, h_kv, d_qk]  bf16, last dim contiguous
 *   indices     : [s_q,  h_kv, topk]  int32, last dim contiguous
 *   attn_sink   : optional [h_q]      float32
 *   topk_length : optional [s_q]      int32
 * Constraints: d_qk in {576, 512}, d_v == 512, h_q in {64, 128}; GPU must be SM90a or SM100f.
 *
 * Returns {out [s_q, h_q, d_v] (q's dtype), max_logits [s_q, h_q] fp32, lse [s_q, h_q] fp32}.
 * Invalid inputs or unsupported configurations raise through TORCH_CHECK.
 */
static std::vector<at::Tensor> sparse_attn_prefill_interface(
    const at::Tensor &q,
    const at::Tensor &kv,
    const at::Tensor &indices,
    float sm_scale,
    int d_v,
    const std::optional<at::Tensor> &attn_sink,
    const std::optional<at::Tensor> &topk_length
) {
    using bf16 = cutlass::bfloat16_t;
    Arch arch = Arch();
    bool is_sm90a = arch.is_sm90a();
    bool is_sm100f = arch.is_sm100f();
    TORCH_CHECK(is_sm90a || is_sm100f, "Sparse Attention Forward Kernel is only supported on SM90a and SM100f architectures.");

    KU_CHECK_NDIM(q, 3);
    KU_CHECK_NDIM(kv, 3);
    KU_CHECK_NDIM(indices, 3);
    KU_CHECK_NDIM(attn_sink, 1);
    KU_CHECK_NDIM(topk_length, 1);

    int s_q = q.size(0);
    int s_kv = kv.size(0);
    int h_q = q.size(1);
    int h_kv = kv.size(1);
    int d_qk = q.size(2);
    int topk = indices.size(2);
    bool have_topk_length = topk_length.has_value();
    TORCH_CHECK(d_qk == 576 || d_qk == 512, "Invalid d_qk: ", d_qk);
    // Same "<name>: <value>" format as the d_qk check above.
    TORCH_CHECK(d_v == 512, "Invalid d_v: ", d_v);

    KU_CHECK_DEVICE(q);
    KU_CHECK_DEVICE(kv);
    KU_CHECK_DEVICE(indices);
    KU_CHECK_DEVICE(attn_sink);
    KU_CHECK_DEVICE(topk_length);

    KU_CHECK_DTYPE(q, torch::kBFloat16);
    KU_CHECK_DTYPE(kv, torch::kBFloat16);
    KU_CHECK_DTYPE(indices, torch::kInt32);
    KU_CHECK_DTYPE(attn_sink, torch::kFloat32);
    KU_CHECK_DTYPE(topk_length, torch::kInt32);

    KU_CHECK_SHAPE(q, s_q, h_q, d_qk);
    KU_CHECK_SHAPE(kv, s_kv, h_kv, d_qk);
    KU_CHECK_SHAPE(indices, s_q, h_kv, topk);
    KU_CHECK_SHAPE(attn_sink, h_q);
    KU_CHECK_SHAPE(topk_length, s_q);

    // Only the innermost stride is assumed to be 1; the outer strides are forwarded to the kernels.
    KU_CHECK_LAST_DIM_CONTIGUOUS(q);
    KU_CHECK_LAST_DIM_CONTIGUOUS(kv);
    KU_CHECK_LAST_DIM_CONTIGUOUS(indices);
    KU_CHECK_LAST_DIM_CONTIGUOUS(attn_sink);
    KU_CHECK_LAST_DIM_CONTIGUOUS(topk_length);

    // Allocate results and buffers on q's device
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
    auto opts = q.options();
    at::Tensor out = torch::empty({s_q, h_q, d_v}, opts);
    at::Tensor lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
    at::Tensor max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
    KU_CHECK_CONTIGUOUS(out);
    KU_CHECK_CONTIGUOUS(lse);
    KU_CHECK_CONTIGUOUS(max_logits);

    SparseAttnFwdParams params = {
        s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,
        sm_scale, sm_scale * LOG_2_E,
        (bf16*)q.data_ptr(),
        (bf16*)kv.data_ptr(),
        (int*)indices.data_ptr(),
        ku::get_optional_tensor_ptr<float>(attn_sink),
        ku::get_optional_tensor_ptr<int>(topk_length),
        int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)),
        int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
        int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),
        (bf16*)out.data_ptr(),
        (float*)max_logits.data_ptr(),
        (float*)lse.data_ptr(),
        arch.num_sms,
        at::cuda::getCurrentCUDAStream().stream()
    };

    // Collect the features this call needs; each implementation's run() is given this list.
    std::vector<FwdFeatures> required_features;
    if (h_q == 64) {
        required_features.push_back(FwdFeatures::HEAD_64);
    } else if (h_q == 128) {
        required_features.push_back(FwdFeatures::HEAD_128);
    } else {
        TORCH_CHECK(false, "Unsupported h_q: ", h_q);
    }
    if (d_qk == 576) {
        required_features.push_back(FwdFeatures::HEAD_DIM_576);
    } else if (d_qk == 512) {
        required_features.push_back(FwdFeatures::HEAD_DIM_512);
    } else {
        TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
    }
    if (attn_sink.has_value()) {
        required_features.push_back(FwdFeatures::ATTN_SINK);
    }
    if (have_topk_length) {
        required_features.push_back(FwdFeatures::TOPK_LENGTH);
    }

    if (is_sm90a) {
        Fwd_Sm90_Impl fwd_impl;
        fwd_impl.run(params, required_features);
    } else if (is_sm100f) {
        if (h_q == 64) {
            Fwd_Sm100_Head64_Impl fwd_impl;
            fwd_impl.run(params, required_features);
        } else if (h_q == 128) {
            // Largest top-k for which the small-top-k kernel is preferred.
            constexpr int kSmallTopkMaxTopk = 1280;
            Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;
            Fwd_Sm100_Head128_Impl regular_impl;
            // Use the small-top-k kernel when top-k is small and it supports every requested
            // feature, or as a fallback whenever the regular kernel cannot handle them.
            const bool use_small_topk_impl =
                (topk <= kSmallTopkMaxTopk && small_topk_impl.check_if_all_features_are_supported(required_features)) ||
                !regular_impl.check_if_all_features_are_supported(required_features);
            if (use_small_topk_impl) {
                small_topk_impl.run(params, required_features);
            } else {
                regular_impl.run(params, required_features);
            }
        } else {
            TORCH_CHECK(false, "Unsupported h_q: ", h_q);
        }
    } else {
        TORCH_CHECK(false, "Unsupported architecture");
    }
    return {out, max_logits, lse};
}
================================================
FILE: csrc/defines.h
================================================
#pragma once
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
// Global shorthand aliases for CUTLASS numeric types and barrier utilities.
// bf16: CUTLASS bfloat16; fp8: CUTLASS float8 e4m3.
using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t;
// CUTLASS cluster transaction barrier (presumably used to track TMA transaction bytes — confirm at use sites).
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
// Bring CUTLASS fence / barrier helpers into the global namespace so kernels can use them unqualified.
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::fence_barrier_init;
using cutlass::arch::NamedBarrier;
// Plain aggregate of eight 32-bit ints, one named field per lane (a0 .. a7).
struct int32x8_t {
    int a0;
    int a1;
    int a2;
    int a3;
    int a4;
    int a5;
    int a6;
    int a7;
};
// Eight floats stored as four CUDA float2 pairs: (a01), (a23), (a45), (a67).
struct float8 {
    float2 a01;
    float2 a23;
    float2 a45;
    float2 a67;
};
// Eight bfloat16 values stored as four packed __nv_bfloat162 pairs.
struct bf16x8 {
    __nv_bfloat162 a01, a23, a45, a67;
};
================================================
FILE: csrc/kerutils/include/kerutils/common/common.h
================================================
#pragma once
// Declare the kerutils namespace so the `ku` alias below is always valid.
namespace kerutils {}
// Print `fmt` (plus any extra arguments) via cute::print, then a newline.
// Both calls are namespace-qualified: an unqualified `print("\n")` only resolves inside
// `namespace cute` or under `using namespace cute;`, because ADL on a `const char*`
// argument cannot find cute::print.
#define KU_PRINTLN(fmt, ...) { cute::print(fmt, ##__VA_ARGS__); cute::print("\n"); }
// Short alias used throughout the code base.
namespace ku = kerutils;
================================================
FILE: csrc/kerutils/include/kerutils/device/common.h
================================================
/*
Common data types and macros that are used across the kerutils library.
*/
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
#include <cute/config.hpp> // For CUTE_DEVICE
namespace kerutils {

// Cache hints. Explicit values equal the implicit declaration-order values.
enum class CacheHint {
    EVICT_FIRST     = 0,
    EVICT_NORMAL    = 1,
    EVICT_LAST      = 2,
    EVICT_UNCHANGED = 3,
    NO_ALLOCATE     = 4,
};

// Prefetch size (explicit values equal the implicit declaration-order values).
enum class PrefetchSize {
    B64  = 0,
    B128 = 1,
    B256 = 2,
};

// Short names for CUDA-native bfloat16 / fp8 (e4m3) scalar and packed types.
typedef __nv_bfloat16    nvbf16;
typedef __nv_bfloat162   nvbf16x2;
typedef __nv_fp8_e4m3    nve4m3;
typedef __nv_fp8x2_e4m3  nve4m3x2;
typedef __nv_fp8x4_e4m3  nve4m3x4;

// CUTLASS-side aliases.
typedef cutlass::bfloat16_t                       bf16;
typedef cutlass::arch::ClusterTransactionBarrier  transac_bar_t;

}  // namespace kerutils
// Architecture feature macros, derived purely from __CUDA_ARCH__. They are therefore only
// defined in device compilation passes (where __CUDA_ARCH__ is defined):
//   KERUTILS_ENABLE_SM80   : __CUDA_ARCH__ >= 800
//   KERUTILS_ENABLE_SM90   : __CUDA_ARCH__ >= 900
//   KERUTILS_ENABLE_SM90A  : 900  <= __CUDA_ARCH__ < 1000
//   KERUTILS_ENABLE_SM100  : __CUDA_ARCH__ >= 1000
//   KERUTILS_ENABLE_SM100A : 1000 <= __CUDA_ARCH__ < 1200
// NOTE(review): the "A" variants check only the numeric __CUDA_ARCH__ range, not the
// arch-specific feature macros — confirm this is the intended gating.
// A device pass for an architecture below SM80 is rejected at compile time by the static_assert.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define KERUTILS_ENABLE_SM80
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
static_assert(false, "kerutils doesn't support SM architectures below SM80");
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define KERUTILS_ENABLE_SM90
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000))
#define KERUTILS_ENABLE_SM90A
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
#define KERUTILS_ENABLE_SM100
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200))
#define KERUTILS_ENABLE_SM100A
#endif
// When __CLION_IDE__ or __VSCODE_IDE__ is defined, enable every feature macro,
// presumably so IDE code analysis sees all arch-gated code paths.
#if (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
#define KERUTILS_ENABLE_SM80
#define KERUTILS_ENABLE_SM90
#define KERUTILS_ENABLE_SM90A
#define KERUTILS_ENABLE_SM100
#define KERUTILS_ENABLE_SM100A
#endif
================================================
FILE: csrc/kerutils/include/kerutils/device/device.cuh
================================================
#pragma once
#include "kerutils/common/common.h"
#include "common.h"
#include "sm80/intrinsics.cuh"
#include "sm80/helpers.cuh"
#include "sm90/intrinsics.cuh"
#include "sm90/helpers.cuh"
#include "sm100/intrinsics.cuh"
#include "sm100/helpers.cuh"
#include "sm100/gemm.cuh"
#include "sm100/tma_cta_group2_nosplit.cuh"
================================================
FILE: csrc/kerutils/include/kerutils/device/sm100/gemm.cuh
================================================
#pragma once
#include <cute/tensor.hpp>
#include <kerutils/device/common.h>
namespace cute {
// Extensions to CuTe
// CuTe doesn't support UTCMMA with .ws in the form we need, so we add it here.
// In addition, CuTe's UTCMMA wrappers call `elect_one_sync()` internally, which we want to avoid, so the variants below omit it and leave selection of the issuing thread to the caller.
// Warp-specialised (.ws) single-CTA UMMA (tcgen05.mma.ws, kind::f16).
// A is read from tensor memory (TMEM address) and B through a shared-memory descriptor.
// fma() issues the instruction unconditionally (no elect_one_sync()), so the caller must
// choose which thread issues it.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_TS_NOELECT
{
    static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
    // The message now matches the values actually accepted by the check.
    static_assert(N == 64 || N == 128 || N == 256,
                  "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256");

    using DRegisters = void;
    // A is a 32-bit TMEM address (fma() takes uint32_t and binds it with an "r" constraint).
    using ARegisters = uint32_t[1];
    using BRegisters = uint64_t[1];
    using CRegisters = uint32_t[1];

    // tmem_a : TMEM address of A
    // desc_b : shared-memory descriptor of B
    // tmem_c : TMEM address of the accumulator
    // scaleC : non-zero -> accumulate into C, zero -> overwrite C (sets predicate p)
    // idescE : instruction descriptor; only its upper 32 bits are passed to the instruction
    CUTE_HOST_DEVICE static void
    fma(uint32_t const& tmem_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scaleC,
        uint64_t const& idescE)
    {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
    }
};
// CuTe MMA_Traits for SM100_MMA_F16BF16_WS_TS_NOELECT. It describes the operand value
// types, fragment kinds (A in TMEM, B via SMEM descriptor, C/D in TMEM) and single-thread
// layouts, and forwards MMA_Atom calls to the raw fma().
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types");
// A fragment lives in TMEM. Conceptually its allocation mode is "Duplicated", but CuTe does
// not allow that here; NonInterleaved is used instead because it gives the correct address calculation.
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
// Logical shape-K is always 256 bits; transform to units of elements (16 for 16-bit types)
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
// Single-"thread" atom: the whole M x N x K tile is owned by one logical thread.
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
// Compile-time instruction descriptor for this type/shape/major/negation combination.
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Unpacks the CuTe tensors (D/A: TMEM addresses, B: SMEM descriptor) and issues fma()
// with `accumulate_` as the scale-C flag. C itself is not referenced; accumulation happens
// in D's TMEM.
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint32_t tmem_a = raw_pointer_cast(A.data());
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
// Warp-specialised (.ws) single-CTA UMMA (tcgen05.mma.ws, kind::f16) with both A and B
// supplied as shared-memory descriptors. fma() issues the instruction unconditionally
// (no elect_one_sync()), so the caller must choose which thread issues it.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_SS_NOELECT
{
    static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
    // The message now matches the values actually accepted by the check.
    static_assert(N == 64 || N == 128 || N == 256,
                  "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256");

    using DRegisters = void;
    using ARegisters = uint64_t[1];
    using BRegisters = uint64_t[1];
    using CRegisters = uint32_t[1];

    // desc_a / desc_b : shared-memory descriptors of A and B
    // tmem_c          : TMEM address of the accumulator
    // scaleC          : non-zero -> accumulate into C, zero -> overwrite C (sets predicate p)
    // idescE          : instruction descriptor; only its upper 32 bits are passed to the instruction
    CUTE_HOST_DEVICE static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scaleC,
        uint64_t const& idescE)
    {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
    }
};
// CuTe MMA_Traits for SM100_MMA_F16BF16_WS_SS_NOELECT. A and B are SMEM descriptors and
// C/D live in TMEM; MMA_Atom calls are forwarded to the raw fma().
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
// Logical shape-K is always 256 bits; transform to units of elements
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
// Single-"thread" atom: the whole tile is owned by one logical thread.
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
// Compile-time instruction descriptor for this type/shape/major/negation combination.
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
// Unpacks D (TMEM address) and the A/B SMEM descriptors and issues fma() with
// `accumulate_` as the scale-C flag. C is not referenced; accumulation happens in D's TMEM.
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
// Two-CTA (cta_group::2) UMMA, kind::f16. A is read from tensor memory and B through a
// shared-memory descriptor. The instruction is issued without an internal elect_one_sync().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
          UMMA::Saturate c_sat = UMMA::Saturate::False>
struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT
{
    static_assert(M == 128 || M == 256,
                  "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
    static_assert((N % 16 == 0) && (16 <= N) && (N <= 256),
                  "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 16 between 16 and 256.");
    static_assert(a_major == UMMA::Major::K,
                  "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed");

    using DRegisters = void;
    using ARegisters = uint32_t[1];
    using BRegisters = uint64_t[1];
    using CRegisters = uint32_t[1];

    // tmem_a: TMEM address of A; desc_b: SMEM descriptor of B; tmem_c: TMEM accumulator;
    // scaleC != 0 accumulates into C; only the upper 32 bits of idescE are passed on.
    CUTE_HOST_DEVICE static void
    fma(uint32_t const& tmem_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scaleC,
        uint64_t const& idescE)
    {
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
        // Eight-word vector operand, all zeros (presumably the disable-output-lane mask: no lanes masked).
        uint32_t lane_mask[8] = {};
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)), "r"(scaleC),
              "r"(lane_mask[0]), "r"(lane_mask[1]), "r"(lane_mask[2]), "r"(lane_mask[3]),
              "r"(lane_mask[4]), "r"(lane_mask[5]), "r"(lane_mask[6]), "r"(lane_mask[7]));
#else
        CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
    }
};
// CuTe MMA_Traits for SM100_MMA_F16BF16_2x1SM_TS_NOELECT (two-CTA UMMA, A in TMEM, B via
// SMEM descriptor, C/D in TMEM). ThrID has two entries, so A and C are split along M and
// B along N between the two participating CTAs.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,
          UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                                     M, N,
                                                     a_major, b_major,
                                                     a_neg, b_neg, c_sat>>
{
    using ValTypeD = c_type;
    using ValTypeA = a_type;
    using ValTypeB = b_type;
    using ValTypeC = c_type;
    static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types");

    using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
    using FrgTypeB = UMMA::smem_desc<b_major>;
    using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

    // Size of instructions' K extent is always 256 bits; convert to units of element
    constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
    using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
    using ThrID = Layout<_2>;
    using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
                           Stride<Int<M/2>,Stride< _1,Int<M>>>>;
    using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
                           Stride<Int<N/2>,Stride< _1,Int<N>>>>;
    using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
                           Stride<Int<M/2>,Stride< _1,Int<M>>>>;

    // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
    UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
    UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
        a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();

    // Unpacks A/D (TMEM addresses) and B (SMEM descriptor) and issues fma() with
    // `accumulate_` as the scale-C flag. C is not referenced; accumulation happens in D's TMEM.
    template <class TD, class DLayout,
              class TA, class ALayout,
              class TB, class BLayout,
              class TC, class CLayout>
    CUTE_HOST_DEVICE constexpr friend
    void
    mma_unpack(MMA_Traits const& traits,
               Tensor<TD, DLayout> & D,
               Tensor<TA, ALayout> const& A,
               Tensor<TB, BLayout> const& B,
               Tensor<TC, CLayout> const& C)
    {
        static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
        // A is a TMEM operand, so the message says "tmem" (it previously said "desc registers").
        static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
        static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
        static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

        // 32-bit TMEM address, matching fma()'s `uint32_t const&` parameter. Previously this was
        // uint64_t and was narrowed through a temporary.
        uint32_t tmem_a = raw_pointer_cast(A.data());
        uint64_t desc_b = B[0];
        uint32_t tmem_c = raw_pointer_cast(D.data());
        uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

        SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                           M, N,
                                           a_major, b_major,
                                           a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
    }
};
// Two-CTA (cta_group::2) UMMA, kind::f16, with A and B both given as shared-memory
// descriptors. Equivalent to SM100_MMA_F16BF16_2x1SM_SS but without elect_one_sync().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT
{
    static_assert(M == 128 || M == 256,
                  "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
    static_assert((N % 16 == 0) && (16 <= N) && (N <= 256),
                  "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 16 between 16 and 256.");

    using DRegisters = void;
    using ARegisters = uint64_t[1];
    using BRegisters = uint64_t[1];
    using CRegisters = uint32_t[1];

    // desc_a / desc_b: SMEM descriptors; tmem_c: TMEM accumulator;
    // scaleC != 0 accumulates into C; only the upper 32 bits of idescE are passed on.
    CUTE_HOST_DEVICE static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scaleC,
        uint64_t const& idescE)
    {
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
        // Eight-word vector operand, all zeros (presumably the disable-output-lane mask: no lanes masked).
        uint32_t lane_mask[8] = {};
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)), "r"(scaleC),
              "r"(lane_mask[0]), "r"(lane_mask[1]), "r"(lane_mask[2]), "r"(lane_mask[3]),
              "r"(lane_mask[4]), "r"(lane_mask[5]), "r"(lane_mask[6]), "r"(lane_mask[7]));
#else
        CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
    }
};
// CuTe MMA_Traits for SM100_MMA_F16BF16_2x1SM_SS_NOELECT (two-CTA UMMA, A and B via SMEM
// descriptors, C/D in TMEM). ThrID has two entries; A and C are split along M and B along N.
template <class a_type, class b_type, class c_type,
int M, int N,
UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
// Size of the instruction's K extent is always 256 bits; convert to units of element
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_2>;
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
// Compile-time instruction descriptor for this type/shape/major/negation combination.
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
// Unpacks D (TMEM address) and the A/B SMEM descriptors and issues fma() with
// `accumulate_` as the scale-C flag. C is not referenced; accumulation happens in D's TMEM.
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
// Single-CTA (cta_group::1) UMMA, kind::f16. A is read from tensor memory and B through a
// shared-memory descriptor. The instruction is issued without an internal elect_one_sync().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
          UMMA::Saturate c_sat = UMMA::Saturate::False>
struct SM100_MMA_F16BF16_TS_NOELECT
{
    static_assert(M == 64 || M == 128,
                  "SM100_MMA_F16BF16_TS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
    static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) ||
                  (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)),
                  "SM100_MMA_F16BF16_TS_NOELECT N-mode size should be a multiple of 8 between 8 and 256 for M=64,"
                  "or a multiple of 16 between 16 and 256 for M=128.");
    static_assert(a_major == UMMA::Major::K,
                  "SM100_MMA_F16BF16_TS_NOELECT A from TMEM can't be transposed");

    using DRegisters = void;
    using ARegisters = uint32_t[1];
    using BRegisters = uint64_t[1];
    using CRegisters = uint32_t[1];

    // tmem_a: TMEM address of A; desc_b: SMEM descriptor of B; tmem_c: TMEM accumulator;
    // scaleC != 0 accumulates into C; only the upper 32 bits of idescE are passed on.
    CUTE_HOST_DEVICE static void
    fma(uint32_t const& tmem_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scaleC,
        uint64_t const& idescE)
    {
        // Four-word vector operand, all zeros (presumably the disable-output-lane mask: no lanes masked).
        uint32_t lane_mask[4] = {};
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)), "r"(scaleC),
              "r"(lane_mask[0]), "r"(lane_mask[1]), "r"(lane_mask[2]), "r"(lane_mask[3]));
    }
};
// CuTe MMA_Traits for SM100_MMA_F16BF16_TS_NOELECT (single-CTA UMMA, A in TMEM, B via SMEM
// descriptor, C/D in TMEM); forwards MMA_Atom calls to the raw fma().
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,
UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_TS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, c_sat>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_TS_NOELECT supports 16bit types");
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_1sm<c_type, int32_t, UMMA::TmemAllocMode::NonInterleaved>;
// Logical shape-K is always 256 bits; transform to units of elements
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
// Single-"thread" atom: the whole tile is owned by one logical thread.
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
// Compile-time instruction descriptor, including the saturation setting c_sat.
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();
// Unpacks A/D (TMEM addresses) and B (SMEM descriptor) and issues fma() with
// `accumulate_` as the scale-C flag. C is not referenced; accumulation happens in D's TMEM.
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint32_t tmem_a = raw_pointer_cast(A.data());
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_TS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
// Single-CTA (cta_group::1) UMMA, kind::f16, with A and B both given as shared-memory
// descriptors. The instruction is issued without an internal elect_one_sync().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_SS_NOELECT
{
    static_assert(M == 64 || M == 128,
                  "SM100_MMA_F16BF16_SS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
    static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) ||
                  (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)),
                  "SM100_MMA_F16BF16_SS_NOELECT N-mode size should be a multiple of 8 between 8 and 256 for M=64,"
                  "or a multiple of 16 between 16 and 256 for M=128.");

    using DRegisters = void;
    using ARegisters = uint64_t[1];
    using BRegisters = uint64_t[1];
    using CRegisters = uint32_t[1];

    // desc_a / desc_b: SMEM descriptors; tmem_c: TMEM accumulator;
    // scaleC != 0 accumulates into C; only the upper 32 bits of idescE are passed on.
    CUTE_HOST_DEVICE static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scaleC,
        uint64_t const& idescE)
    {
        // Four-word vector operand, all zeros (presumably the disable-output-lane mask: no lanes masked).
        uint32_t lane_mask[4] = {};
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)), "r"(scaleC),
              "r"(lane_mask[0]), "r"(lane_mask[1]), "r"(lane_mask[2]), "r"(lane_mask[3]));
    }
};
// MMA_Traits for SM100_MMA_F16BF16_SS_NOELECT, allowing it to be used as a CuTe MMA atom.
// The A and B fragments are shared-memory descriptors; the C and D fragments live in tensor memory.
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
// Both A and B must be 16-bit element types.
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_SS_NOELECT supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_1sm<c_type>;
// Logical shape-K is always 256bits, transform to units of elements
// (with the 16-bit types enforced above, K == 16).
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
// A single "thread" owns the whole MMA tile.
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
// Compile-time instruction descriptor; converted to its runtime form in mma_unpack.
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
// Unpack the CuTe fragments into raw operands and issue the MMA.
// The C argument is not read: the accumulator address is taken from D, and accumulate_ selects whether
// its existing contents are used.
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm100/helpers.cuh
================================================
#pragma once
#include <cute/tensor.hpp>
#include "kerutils/device/common.h"
namespace kerutils {
// Perform SS UTCMMA: both operands are read from shared memory, the accumulator is a tmem fragment.
// sA and sB must be shared-memory tensors (i.e. make_tensor(make_smem_ptr(XXX), XXX)); tC_frag must be a
// tmem fragment.
// One cute::gemm is issued per K block of the partitioned operands, in increasing k order.
// clear_accum: if true, the first K block overwrites tC_frag (ScaleOut::Zero); every later K block
//              accumulates. If false, all K blocks accumulate onto the existing contents.
// tiled_mma is taken by reference; on return its accumulate_ is ScaleOut::One (for a non-empty K extent).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma_ss(
    TiledMMA &tiled_mma,
    TensorA sA,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    using namespace cute;
    tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;

    // A/B/C are already CTA-local tiles, so the slice index is irrelevant.
    auto mma_slice = tiled_mma.get_slice(_0{});
    auto a_frags = mma_slice.partition_fragment_A(sA);
    auto b_frags = mma_slice.partition_fragment_B(sB);

    // Shape agreement: same number of K blocks, A's M matches C's M, B's N matches C's N.
    static_assert(size<2>(a_frags) == size<2>(b_frags));
    static_assert(size<1>(a_frags) == size<1>(tC_frag));
    static_assert(size<1>(b_frags) == size<2>(tC_frag));

    CUTE_UNROLL
    for (int k_blk = 0; k_blk < size<2>(a_frags); ++k_blk) {
        cute::gemm(tiled_mma, a_frags(_, _, k_blk), b_frags(_, _, k_blk), tC_frag);
        // Everything after the first K block must accumulate onto the partial sum.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Perform TS UTCMMA: operand A and the accumulator are tmem fragments, operand B is in shared memory.
// sB must be a shared-memory tensor (i.e. make_tensor(make_smem_ptr(XXX), XXX)); tA_frag and tC_frag
// must be tmem fragments.
// One cute::gemm is issued per K block, in increasing k order.
// clear_accum: if true, the first K block overwrites tC_frag (ScaleOut::Zero); every later K block
//              accumulates. If false, all K blocks accumulate onto the existing contents.
// tiled_mma is taken by reference; on return its accumulate_ is ScaleOut::One (for a non-empty K extent).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma_ts(
    TiledMMA &tiled_mma,
    TensorA tA_frag,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    using namespace cute;
    tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;

    // A/B/C are already CTA-local tiles, so the slice index is irrelevant.
    auto mma_slice = tiled_mma.get_slice(_0{});
    auto b_frags = mma_slice.partition_fragment_B(sB);

    // A (already partitioned in tmem) and B must be split into the same number of K blocks.
    static_assert(size<2>(tA_frag) == size<2>(b_frags));

    CUTE_UNROLL
    for (int k_blk = 0; k_blk < size<2>(tA_frag); ++k_blk) {
        cute::gemm(tiled_mma, tA_frag(_, _, k_blk), b_frags(_, _, k_blk), tC_frag);
        // Everything after the first K block must accumulate onto the partial sum.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Build a K-major UMMA shared-memory layout of shape (MN, K) for element type T.
// SWIZZLE picks the layout atom: 128 -> SW128, 64 -> SW64, 32 -> SW32, 0 or 16 -> INTER.
// Any other value fails the static_assert. The atom is tiled over the (MN, K) shape with MN as the
// first replication mode (Step<_1, _2>), then the result is coalesced per mode.
template<int MN, int K, int SWIZZLE, typename T = bf16>
static constexpr auto make_umma_canonical_k_major_layout() {
    using namespace cute;
    using KAtom =
        std::conditional_t<SWIZZLE == 128, UMMA::Layout_K_SW128_Atom<T>,
        std::conditional_t<SWIZZLE == 64,  UMMA::Layout_K_SW64_Atom<T>,
        std::conditional_t<SWIZZLE == 32,  UMMA::Layout_K_SW32_Atom<T>,
        std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16, UMMA::Layout_K_INTER_Atom<T>,
        void>>>>;
    static_assert(!std::is_same_v<KAtom, void>, "Invalid SWIZZLE value");
    auto const tiled = tile_to_shape(KAtom{}, Shape<Int<MN>, Int<K>>{}, Step<_1, _2>{});
    return coalesce(tiled, Shape<_1, _1>{});
}
// Build an MN-major UMMA shared-memory layout of shape (MN, K) for element type T.
// SWIZZLE picks the layout atom: 128 -> SW128, 64 -> SW64, 32 -> SW32, 0 or 16 -> INTER.
// Any other value fails the static_assert. The atom is tiled over the (MN, K) shape with K as the
// first replication mode (Step<_2, _1>), then the result is coalesced per mode.
template<int MN, int K, int SWIZZLE, typename T = bf16>
static constexpr auto make_umma_canonical_mn_major_layout() {
    using namespace cute;
    using MNAtom =
        std::conditional_t<SWIZZLE == 128, UMMA::Layout_MN_SW128_Atom<T>,
        std::conditional_t<SWIZZLE == 64,  UMMA::Layout_MN_SW64_Atom<T>,
        std::conditional_t<SWIZZLE == 32,  UMMA::Layout_MN_SW32_Atom<T>,
        std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16, UMMA::Layout_MN_INTER_Atom<T>,
        void>>>>;
    static_assert(!std::is_same_v<MNAtom, void>, "Invalid SWIZZLE value");
    auto const tiled = tile_to_shape(MNAtom{}, Shape<Int<MN>, Int<K>>{}, Step<_2, _1>{});
    return coalesce(tiled, Shape<_1, _1>{});
}
// Build the canonical UMMA shared-memory layout for an (MN, K) operand tile of element type T.
//   MAJOR   : UMMA::Major::K -> K-major layout; any other value -> MN-major layout.
//   SWIZZLE : forwarded unchanged to the selected helper (accepted values: 0, 16, 32, 64, 128).
// Declared constexpr, like the two helpers it dispatches to. The result can then be used in constant
// expressions and called from device code when relaxed constexpr is enabled (e.g. nvcc
// --expt-relaxed-constexpr). Without it, this non-annotated function is host-only.
template<cute::UMMA::Major MAJOR, int MN, int K, int SWIZZLE, typename T = bf16>
constexpr auto make_umma_canonical_layout() {
    if constexpr (MAJOR == cute::UMMA::Major::K) {
        return make_umma_canonical_k_major_layout<MN, K, SWIZZLE, T>();
    } else {
        return make_umma_canonical_mn_major_layout<MN, K, SWIZZLE, T>();
    }
}
}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh
================================================
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// tma gather4 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios
// Issues a 2D `tile::gather4` TMA load (cta_group::1) from the tensor map at desc_ptr into CTA shared memory
// at smem_ptr. Coordinates are {col_idx, row_idxs.x, row_idxs.y, row_idxs.z, row_idxs.w}, i.e. one
// column start plus four row indices. Completion is reported on mbar_ptr via complete_tx::bytes.
// cache_hint is passed as the 64-bit L2::cache_hint policy operand.
CUTE_DEVICE
void tma_gather4(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(cache_hint)
: "memory"
);
}
// tma gather4 prefetch (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios
// L2-only counterpart of tma_gather4. It prefetches the same {col_idx, 4 row indices} gather4 tile into L2,
// with no shared-memory destination and no mbarrier. cache_hint is the 64-bit L2::cache_hint policy operand.
CUTE_DEVICE
void tma_gather4_prefetch(const void* desc_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
asm volatile(
"cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint [%0, {%1, %2, %3, %4, %5}], %6;\n"
:
: "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"l"(cache_hint)
);
}
// tma gather4 with cta_group::2, allowing for synchronization across CTAs within a pair of CTAs (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
// Same operands as tma_gather4, but issued with `.cta_group::2`.
// USE_CTA0_MBAR: when true, the mbarrier address is ANDed with cute::Sm100MmaPeerBitMask. That clears the
// peer bit so the transaction bytes are accounted on CTA0's barrier (the same trick used by the
// SM100_TMA_2SM_LOAD_*_NOSPLIT atoms). When false, mbar_ptr's address is used as-is.
// NOTE: as with tma_gather4, the gather4 coordinates are int32.
template<bool USE_CTA0_MBAR = false>
CUTE_DEVICE void tma_gather4_cta_group_2(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);
if constexpr (USE_CTA0_MBAR) {
mbar_addr &= cute::Sm100MmaPeerBitMask;
}
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(cache_hint)
: "memory"
);
}
// Vectorized addition for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add)
// Returns the lane-wise sum {a.x + b.x, a.y + b.y} computed by a single `add.f32x2`.
// Each float2 is handed to the asm as one 64-bit register by reinterpreting it as uint64_t.
CUTE_DEVICE
float2 float2_add(const float2 &a, const float2 &b) {
float2 c;
asm volatile(
"add.f32x2 %0, %1, %2;\n"
: "=l"(reinterpret_cast<uint64_t&>(c))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b))
);
return c;
}
// Vectorized multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-mul)
// Returns the lane-wise product {a.x * b.x, a.y * b.y} computed by a single `mul.f32x2`.
// Each float2 is handed to the asm as one 64-bit register by reinterpreting it as uint64_t.
CUTE_DEVICE
float2 float2_mul(const float2 &a, const float2 &b) {
float2 c;
asm volatile(
"mul.f32x2 %0, %1, %2;\n"
: "=l"(reinterpret_cast<uint64_t&>(c))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)));
return c;
}
// Vectorized fused addition-multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma)
// Returns the lane-wise fused a * b + c, using one `fma.rn.f32x2` (round-to-nearest-even).
// Each float2 is handed to the asm as one 64-bit register by reinterpreting it as uint64_t.
CUTE_DEVICE
float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {
// return a*b+c
float2 d;
asm volatile(
"fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
return d;
}
// Vectorized negation for float32: multiplies both lanes of `a` by -1.0f through float2_mul
// (a single mul.f32x2) and returns the result.
CUTE_DEVICE
float2 float2_neg(const float2 &a) {
    const float2 minus_one{-1.0f, -1.0f};
    return float2_mul(a, minus_one);
}
// st.bulk (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk)
// Initializes `size` bytes of CTA shared memory starting at dst_ptr. The trailing immediate operand 0 is
// the init value, so the region is zero-filled.
// NOTE(review): PTX imposes alignment/multiple-of-8 requirements on the address and size — confirm callers satisfy them.
CUTE_DEVICE
void st_bulk(void* dst_ptr, int64_t size) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
asm volatile (
"st.bulk.weak.shared::cta [%0], %1, 0;\n"
:
: "r"(dst_addr), "l"(size)
: "memory"
);
}
// 16-byte, 16-byte-aligned shared-memory slot that receives the response of a CLC try_cancel query
// (written by issue_clc_query* and decoded by get_clc_query_response, which reads it with a .b128 load).
struct CUTE_ALIGNAS(16) CLCResponseObj {
// An opaque 16B value
char opaque[16];
};
// Decoded form of a CLCResponseObj, as produced by get_clc_query_response().
struct CLCResult {
    int is_valid;  // 1 if query_cancel.is_canceled reported true, 0 otherwise
    int x;         // get_first_ctaid::x of the response; only meaningful when is_valid != 0
    int y;         // get_first_ctaid::y of the response; only meaningful when is_valid != 0
    int z;         // get_first_ctaid::z of the response; only meaningful when is_valid != 0
};
// Issue a CLC try_cancel query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
// Asynchronously writes the 16-byte response into response_obj (this CTA's shared memory) and signals
// completion on `bar` via mbarrier::complete_tx::bytes. Decode it with get_clc_query_response once `bar` completes.
// NOTE(review): the caller presumably arms `bar` with the 16-byte transaction count beforehand — confirm.
CUTE_DEVICE
void issue_clc_query(transac_bar_t &bar, CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [%0], [%1];\n"
:
: "r"(response_addr), "r"(mbarrier_addr)
);
}
// Issue a CLC try_cancel query with .multicast::cluster::all (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
// Same as issue_clc_query, but with `.multicast::cluster::all`. Per the PTX ISA, the response is delivered
// to the same response/mbarrier offsets in every CTA of the cluster.
CUTE_DEVICE
void issue_clc_query_multicast_cluster_all(transac_bar_t &bar, CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n"
:
: "r"(response_addr), "r"(mbarrier_addr)
);
}
// Get the result of a CLC query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel)
// In this function, we separate get_first_ctaid::x/y/z and hope PTXAS's dead code elimination can remove unnecessary instructions
//
// Loads the 16-byte response from response_obj (a .b128 shared-memory load) and decodes it:
//   is_valid : 1 if query_cancel.is_canceled reports true, otherwise 0.
//   x, y, z  : get_first_ctaid::x/y/z of the response. They are written only under predicate p1, so they
//              hold indeterminate values when is_valid == 0. Check is_valid before using them.
// USE_LD_ACQUIRE: true -> the load uses `.acquire.cta` ordering; false -> a plain load.
template<bool USE_LD_ACQUIRE>
CUTE_DEVICE
CLCResult get_clc_query_response(CLCResponseObj &response_obj) {
    uint32_t response_addr = cute::cast_smem_ptr_to_uint(&response_obj);
    CLCResult result;
    // Local helper macro so both load flavours share one asm body. It uses a kerutils-specific name and is
    // #undef'd right after use, so it neither leaks into nor collides with code that includes this header.
#define KU_CLC_QUERY_EMIT_ASM(LD_MODIFIER) \
    asm volatile( \
        "{\n" \
        ".reg .pred p1;\n\t" \
        ".reg .b128 clc_result;\n\t" \
        "ld" LD_MODIFIER ".shared.b128 clc_result, [%4];\n\t" \
        "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" \
        "selp.u32 %3, 1, 0, p1;\n\t" \
        "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 %0, clc_result;\n\t" \
        "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::y.b32.b128 %1, clc_result;\n\t" \
        "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::z.b32.b128 %2, clc_result;\n\t" \
        "}\n" \
        : "=r"(result.x), "=r"(result.y), "=r"(result.z), "=r"(result.is_valid) \
        : "r"(response_addr) \
        : "memory" \
    );
    if constexpr (USE_LD_ACQUIRE) {
        KU_CLC_QUERY_EMIT_ASM(".acquire.cta");
    } else {
        KU_CLC_QUERY_EMIT_ASM("");
    }
#undef KU_CLC_QUERY_EMIT_ASM
    return result;
}
// LDG.256 or LDG.256 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)
// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function
// NC_STR should be either "" or ".nc"
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last"
// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B"
// Loads 32 bytes from `global_addr` as one `.v4.u64` and stores them into result[0..3], viewing `result` as a
// uint64_t pointer. The *_STR arguments must be string literals because they are spliced into the PTX string
// by literal concatenation.
// NOTE(review): `result` must provide at least 32 bytes. PTX presumably requires 32-byte alignment of
// global_addr for a v4.u64 access — confirm at call sites.
#define KU_LDG_256(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, "`result` must be a pointer"); \
uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \
asm volatile( \
"ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v4.u64 {%0, %1, %2, %3}, [%4];\n" \
: "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]), \
"=l"(result_as_uint64_ptr[2]), "=l"(result_as_uint64_ptr[3]) \
: "l"(global_addr) \
); \
}
// STG.256 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st)
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last"
// Stores 32 bytes, read from src[0..3] viewed as uint64_t, to `global_addr` as one `.v4.u64` store.
// The *_STR arguments must be string literals (they are concatenated into the PTX string).
// NOTE(review): the asm has no "memory" clobber, so the compiler is not told that global memory is written —
// callers presumably rely on other ordering; confirm if this store must be ordered against plain C++ loads.
#define KU_STG_256(global_addr, src, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(src)> || std::is_array_v<decltype(src)>, "`src` must be a pointer"); \
uint64_t const* src_as_uint64_ptr = (uint64_t const*)(src); \
asm volatile( \
"st.global.L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".v4.u64 [%0], {%1, %2, %3, %4};\n" \
: \
: "l"(global_addr), "l"(src_as_uint64_ptr[0]), "l"(src_as_uint64_ptr[1]), \
"l"(src_as_uint64_ptr[2]), "l"(src_as_uint64_ptr[3]) \
); \
}
}
namespace kerutils {
// tcgen05.commit.cta_group::1 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
// Makes `bar` perform one arrive once this thread's previously issued tcgen05 async operations complete.
// "_noelect": there is no thread election here; every calling thread issues the commit.
CUTE_DEVICE
void umma_arrive_noelect(transac_bar_t &bar) {
    asm volatile(
        "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"
        :
        : "r"(cute::cast_smem_ptr_to_uint(&bar))
    );
}
// tcgen05.commit.cta_group::1, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
// Like umma_arrive_noelect, but the arrive is multicast to the mbarrier at the same offset in every CTA
// selected by the 16-bit cta_mask. No thread election is performed.
CUTE_DEVICE
void umma_arrive_multicast_noelect(transac_bar_t &bar, uint16_t cta_mask) {
    asm volatile(
        "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n"
        :
        : "r"(cute::cast_smem_ptr_to_uint(&bar)), "h"(cta_mask)
    );
}
// tcgen05.commit.cta_group::2 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
// 2-CTA (cta_group::2) version of umma_arrive_noelect: one arrive on `bar` once the previously issued
// tcgen05 async operations complete. No thread election is performed.
CUTE_DEVICE
void umma_arrive_2x1SM_noelect(transac_bar_t &bar) {
    asm volatile(
        "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"
        :
        : "r"(cute::cast_smem_ptr_to_uint(&bar))
    );
}
// tcgen05.commit.cta_group::2, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
// 2-CTA (cta_group::2) version of umma_arrive_multicast_noelect: the arrive is multicast to the mbarrier at
// the same offset in every CTA selected by the 16-bit cta_mask. No thread election is performed.
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &bar, uint16_t cta_mask) {
    asm volatile(
        "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n"
        :
        : "r"(cute::cast_smem_ptr_to_uint(&bar)), "h"(cta_mask)
    );
}
// tcgen05.fence::before_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)
// Emits the fence that orders prior tcgen05 operations before a subsequent thread-synchronization point.
CUTE_DEVICE
void tcgen05_before_thread_sync() {
    asm volatile("tcgen05.fence::before_thread_sync;");
}
// tcgen05.fence::after_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)
// Emits the fence that orders subsequent tcgen05 operations after a preceding thread-synchronization point.
CUTE_DEVICE
void tcgen05_after_thread_sync() {
    asm volatile("tcgen05.fence::after_thread_sync;");
}
// Load from tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
// tmem_start : tensor-memory address to load from.
// data_      : destination of kNumElements 32-bit values for the calling thread (accessed as uint32_t[kNumElements]).
// Dispatches at compile time to cute::SM100_TMEM_LOAD_32dp32b<kNumElements>x, passing data[0..kNumElements-1]
// as the destination registers via the index-sequence pack expansion.
// NOTE(review): no tcgen05.wait::ld is issued here; callers presumably wait before reading data_ — confirm.
template <int kNumElements>
__device__ __forceinline__
void tmem_ld_32dp32bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements");
// NOTE The following code crashes VSCode intellisense engine, so we disable it
// (when __VSCODE_IDE__ is defined, this function compiles to an empty body).
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumElements == 1) {
cute::SM100_TMEM_LOAD_32dp32b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 2) {
cute::SM100_TMEM_LOAD_32dp32b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 4) {
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 8) {
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 16) {
cute::SM100_TMEM_LOAD_32dp32b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumElements>{});
#endif
}
// Load from tensor memory, 16 data path lanes, 128-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
// tmem_start : tensor-memory address to load from.
// data_      : destination of 2 * kNumReplications 32-bit values for the calling thread
//              (the index sequence below has kNumReplications*2 entries).
// Dispatches at compile time to cute::SM100_TMEM_LOAD_16dp128b<kNumReplications>x.
// When __VSCODE_IDE__ is defined, the body is compiled out.
// NOTE(review): no tcgen05.wait::ld is issued here; callers presumably wait before reading data_ — confirm.
template <int kNumReplications>
__device__ __forceinline__
void tmem_ld_16dp128bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32 || kNumReplications == 64, "Invalid kNumReplications");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumReplications == 1) {
cute::SM100_TMEM_LOAD_16dp128b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 2) {
cute::SM100_TMEM_LOAD_16dp128b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 4) {
cute::SM100_TMEM_LOAD_16dp128b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 8) {
cute::SM100_TMEM_LOAD_16dp128b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 16) {
cute::SM100_TMEM_LOAD_16dp128b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 32) {
cute::SM100_TMEM_LOAD_16dp128b32x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 64) {
cute::SM100_TMEM_LOAD_16dp128b64x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumReplications*2>{});
#endif
}
// Load from tensor memory, 16 data path lanes, 256-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
// tmem_start : tensor-memory address to load from.
// data_      : destination of 4 * kNumReplications 32-bit values for the calling thread
//              (the index sequence below has kNumReplications*4 entries).
// Dispatches at compile time to cute::SM100_TMEM_LOAD_16dp256b<kNumReplications>x.
// When __VSCODE_IDE__ is defined, the body is compiled out.
// NOTE(review): no tcgen05.wait::ld is issued here; callers presumably wait before reading data_ — confirm.
template <int kNumReplications>
__device__ __forceinline__
void tmem_ld_16dp256bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32, "Invalid kNumReplications");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumReplications == 1) {
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 2) {
cute::SM100_TMEM_LOAD_16dp256b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 4) {
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 8) {
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 16) {
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 32) {
cute::SM100_TMEM_LOAD_16dp256b32x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumReplications*4>{});
#endif
}
// Store into tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)
// tmem_start : tensor-memory address to store to.
// data_      : source of kNumElements 32-bit values from the calling thread (accessed as uint32_t const[kNumElements]).
// Dispatches at compile time to cute::SM100_TMEM_STORE_32dp32b<kNumElements>x.
// When __VSCODE_IDE__ is defined, the body is compiled out.
// NOTE(review): no tcgen05.wait::st is issued here; callers presumably wait where needed — confirm.
template <int kNumElements>
__device__ __forceinline__
void tmem_st_32dp32bNx(uint32_t tmem_start, void const* data_) {
uint32_t const* data = (uint32_t const*)data_;
static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumElements == 1) {
cute::SM100_TMEM_STORE_32dp32b1x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 2) {
cute::SM100_TMEM_STORE_32dp32b2x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 4) {
cute::SM100_TMEM_STORE_32dp32b4x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 8) {
cute::SM100_TMEM_STORE_32dp32b8x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 16) {
cute::SM100_TMEM_STORE_32dp32b16x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 32) {
cute::SM100_TMEM_STORE_32dp32b32x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 64) {
cute::SM100_TMEM_STORE_32dp32b64x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 128) {
cute::SM100_TMEM_STORE_32dp32b128x::copy(data[Is]..., tmem_start);
}
}(cute::make_index_sequence<kNumElements>{});
#endif
}
}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh
================================================
#pragma once
#include <cute/tensor.hpp>
#include <kerutils/device/common.h>
namespace cute {
// Extensions to CuTe
// CuTe's built-in SM100_TMA_2SM_LOAD_1D series requires exactly two participating threads (`using ThrID = Layout<_2>;`) and also splits the data, which makes it inconvenient to use, so we provide our own modified versions. Additionally, for consistency with the code paths that use SM90 TMA, these versions accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100.
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
////////////////////////////////////////////////////////////////////////////////////////////////////
// 1D cta_group::2 TMA tile load that does NOT split the box between the CTA pair. The box at crd0 is copied
// into smem_ptr, and completion bytes are accounted on CTA0's mbarrier (see the in-body comment).
// Outside CUTE_ARCH_TMA_SM100_ENABLED builds, it hits CUTE_INVALID_CONTROL_PATH.
struct SM100_TMA_2SM_LOAD_1D_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
[[maybe_unused]] void * smem_ptr,
[[maybe_unused]] int32_t const& crd0)
{
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Executed by both CTAs. Set peer bit to 0 so that the
// transaction bytes will update CTA0's barrier.
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3}], [%2], %4;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "l"(cache_hint)
: "memory");
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
}
};
// 2D cta_group::2 TMA tile load that does NOT split the box between the CTA pair. The box at (crd0, crd1)
// is copied into smem_ptr, and completion bytes are accounted on CTA0's mbarrier.
// Every parameter is [[maybe_unused]], as in SM100_TMA_2SM_LOAD_1D_NOSPLIT. The parameters are referenced
// only when CUTE_ARCH_TMA_SM100_ENABLED is defined, so other passes of this CUTE_HOST_DEVICE function
// (e.g. the host pass) would otherwise emit unused-parameter warnings.
struct SM100_TMA_2SM_LOAD_2D_NOSPLIT
{
    CUTE_HOST_DEVICE static void
    copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
         [[maybe_unused]] void * smem_ptr,
         [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1)
    {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
        uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
        // Executed by both CTAs. Set peer bit to 0 so that the
        // transaction bytes will update CTA0's barrier.
        uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
        uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
        asm volatile (
            "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
            " [%0], [%1, {%3, %4}], [%2], %5;"
            :
            : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
              "r"(crd0), "r"(crd1), "l"(cache_hint)
            : "memory");
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
    }
};
// 3D cta_group::2 TMA tile load that does NOT split the box between the CTA pair. The box at
// (crd0, crd1, crd2) is copied into smem_ptr, and completion bytes are accounted on CTA0's mbarrier.
// Every parameter is [[maybe_unused]], as in SM100_TMA_2SM_LOAD_1D_NOSPLIT. The parameters are referenced
// only when CUTE_ARCH_TMA_SM100_ENABLED is defined, so other passes of this CUTE_HOST_DEVICE function
// (e.g. the host pass) would otherwise emit unused-parameter warnings.
struct SM100_TMA_2SM_LOAD_3D_NOSPLIT
{
    CUTE_HOST_DEVICE static void
    copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
         [[maybe_unused]] void * smem_ptr,
         [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1, [[maybe_unused]] int32_t const& crd2)
    {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
        uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
        // Executed by both CTAs. Set peer bit to 0 so that the
        // transaction bytes will update CTA0's barrier.
        uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
        uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
        asm volatile (
            "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
            " [%0], [%1, {%3, %4, %5}], [%2], %6;"
            :
            : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
              "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
            : "memory");
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
    }
};
// 4D cta_group::2 TMA tile load that does NOT split the box between the CTA pair. The box at
// (crd0, crd1, crd2, crd3) is copied into smem_ptr, and completion bytes are accounted on CTA0's mbarrier.
// Every parameter is [[maybe_unused]], as in SM100_TMA_2SM_LOAD_1D_NOSPLIT. The parameters are referenced
// only when CUTE_ARCH_TMA_SM100_ENABLED is defined, so other passes of this CUTE_HOST_DEVICE function
// (e.g. the host pass) would otherwise emit unused-parameter warnings.
struct SM100_TMA_2SM_LOAD_4D_NOSPLIT
{
    CUTE_HOST_DEVICE static void
    copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
         [[maybe_unused]] void * smem_ptr,
         [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1,
         [[maybe_unused]] int32_t const& crd2, [[maybe_unused]] int32_t const& crd3)
    {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
        uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
        // Executed by both CTAs. Set peer bit to 0 so that the
        // transaction bytes will update CTA0's barrier.
        uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
        uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
        asm volatile (
            "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
            " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
            :
            : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
              "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
            : "memory");
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
    }
};
// 5D cta_group::2 TMA tile load that does NOT split the box between the CTA pair. The box at
// (crd0, ..., crd4) is copied into smem_ptr, and completion bytes are accounted on CTA0's mbarrier.
// Every parameter is [[maybe_unused]], as in SM100_TMA_2SM_LOAD_1D_NOSPLIT. The parameters are referenced
// only when CUTE_ARCH_TMA_SM100_ENABLED is defined, so other passes of this CUTE_HOST_DEVICE function
// (e.g. the host pass) would otherwise emit unused-parameter warnings.
struct SM100_TMA_2SM_LOAD_5D_NOSPLIT
{
    CUTE_HOST_DEVICE static void
    copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
         [[maybe_unused]] void * smem_ptr,
         [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1,
         [[maybe_unused]] int32_t const& crd2, [[maybe_unused]] int32_t const& crd3,
         [[maybe_unused]] int32_t const& crd4)
    {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
        uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
        // Executed by both CTAs. Set peer bit to 0 so that the
        // transaction bytes will update CTA0's barrier.
        uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
        uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
        asm volatile (
            "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
            " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
            :
            : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
              "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
            : "memory");
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
    }
};
// Rank-dispatching front end for the NOSPLIT 2-CTA TMA loads. The copy() overload is chosen by the number
// of coordinates (1 to 5) and forwards unchanged to SM100_TMA_2SM_LOAD_{1..5}D_NOSPLIT::copy.
// Prefetching reuses SM90_TMA_LOAD's PREFETCH implementation.
struct SM100_TMA_2SM_LOAD_NOSPLIT
{
    CUTE_HOST_DEVICE static void
    copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
         void * smem_ptr,
         int32_t const& crd0)
    {
        SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0);
    }

    CUTE_HOST_DEVICE static void
    copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
         void * smem_ptr,
         int32_t const& crd0, int32_t const& crd1)
    {
        SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1);
    }

    CUTE_HOST_DEVICE static void
    copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
         void * smem_ptr,
         int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
    {
        SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2);
    }

    CUTE_HOST_DEVICE static void
    copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
         void * smem_ptr,
         int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
    {
        SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3);
    }

    CUTE_HOST_DEVICE static void
    copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
         void * smem_ptr,
         int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
    {
        SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
    }

    using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
};
// Same copy() overloads as SM100_TMA_2SM_LOAD_NOSPLIT, but a distinct type. The `with()` methods of
// Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, ...> return Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, ...>.
struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : public SM100_TMA_2SM_LOAD_NOSPLIT {};
// The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
// It owns the TMA descriptor and the aux params (whose g_stride_ drives get_tma_tensor); copying
// through it directly is disabled (copy_unpack is deleted at the end of this struct).
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_NOSPLIT arguments
TmaDescriptor tma_desc_;
using AuxParams = AuxParams_;
AuxParams aux_params_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar
// The cache-hint enum is converted to its raw 64-bit policy value for the executable traits.
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
with(
uint64_t& tma_mbar,
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
// Uses `new_tma_desc` instead of the stored descriptor.
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
with(
TmaDescriptor const* new_tma_desc,
uint64_t& tma_mbar,
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Generate the TMA coord tensor
// g_shape must be congruent with aux_params_.g_stride_ (checked at compile time).
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_));
}
// Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
// The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar
// Produced by the .with(...) overloads of the non-executable traits. The copy entry point is
// inherited from TMA_LOAD_Unpack; opargs_ holds (descriptor, mbarrier, cache hint), which are
// presumably forwarded ahead of the smem pointer and coordinates to
// SM100_TMA_2SM_LOAD_NOSPLIT_OP::copy -- NOTE(review): confirm against TMA_LOAD_Unpack.
template <class NumBitsPerTMA>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_NOSPLIT arguments
tuple<
TmaDescriptor const*,
uint64_t*, // smem mbarrier
uint64_t // cache hint
> const opargs_;
CUTE_HOST_DEVICE
Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache)
: opargs_(desc, mbar, cache) {}
};
}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm80/helpers.cuh
================================================
#pragma once
#include "kerutils/device/common.h"
#include "kerutils/device/sm80/intrinsics.cuh"
namespace kerutils {
// Read `%smid` through get_sm_id() and execute `trap` unless it is strictly below
// `num_physical_sms`; otherwise return the SM id unchanged.
CUTE_DEVICE
uint32_t get_sm_id_with_range_check(uint32_t num_physical_sms) {
    const uint32_t smid = get_sm_id();
    const bool in_range = smid < num_physical_sms;
    if (!in_range) {
        trap();
    }
    return smid;
}
// Device-side assertion: if `cond` is false, execute the PTX `trap` instruction (no message is printed).
// `asm volatile` is required here. The asm has no outputs, and the CUDA inline-PTX docs note that the
// compiler may otherwise delete or move it. This matches `trap()` in sm80/intrinsics.cuh.
#ifndef KU_TRAP_ONLY_DEVICE_ASSERT
#define KU_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) \
        asm volatile("trap;"); \
} while (0)
#endif
// Broadcast a single float into both components (x and y) of a float2.
CUTE_DEVICE
float2 float2float2(const float &x) {
    float2 pair;
    pair.x = x;
    pair.y = x;
    return pair;
}
// 128-bit store (`st.shared.b128`) of `val` to the shared-memory location that the generic pointer `ptr` refers to.
CUTE_DEVICE
void st_shared(void* ptr, __int128_t val) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("st.shared.b128 [%0], %1;" :: "l"(smem_addr), "q"(val));
}
// Store a float4 to shared memory by reinterpreting its 16 bytes as one 128-bit value.
CUTE_DEVICE
void st_shared(void* ptr, float4 val) {
    const __int128_t bits = *reinterpret_cast<__int128_t*>(&val);
    st_shared(ptr, bits);
}
// 128-bit load (`ld.shared.b128`) from the shared-memory location that the generic pointer `ptr` refers to.
CUTE_DEVICE
__int128_t ld_shared(void* ptr) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    __int128_t loaded;
    asm volatile("ld.shared.b128 %0, [%1];" : "=q"(loaded) : "l"(smem_addr));
    return loaded;
}
// Load 16 bytes from shared memory via ld_shared() and reinterpret them as a float4.
CUTE_DEVICE
float4 ld_shared_float4(void* ptr) {
    __int128_t bits = ld_shared(ptr);
    return *reinterpret_cast<float4*>(&bits);
}
}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh
================================================
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// cp.async.cg (cache global) with prefetch and predicate (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async)
// Asynchronously copies 16 bytes from global `src` to shared `dst`, with an L2 prefetch-size hint
// selected by PREFETCH_SIZE (64B / 128B / 256B). When `pred` is false the src-size operand is 0;
// per the PTX cp.async rules, the missing source bytes are zero-filled in the destination.
template<PrefetchSize PREFETCH_SIZE=PrefetchSize::B128>
CUTE_DEVICE
void cp_async_cacheglobal(const void* src, void* dst, bool pred=true) {
    static_assert(PREFETCH_SIZE == PrefetchSize::B64 ||
                  PREFETCH_SIZE == PrefetchSize::B128 ||
                  PREFETCH_SIZE == PrefetchSize::B256,
                  "Unsupported prefetch size for cp_async_cacheglobal.");
    const uint32_t smem_dst = cute::cast_smem_ptr_to_uint(dst);
    const int src_bytes = pred ? 16 : 0;
    if constexpr (PREFETCH_SIZE == PrefetchSize::B64) {
        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16, %2;\n"
                     :: "r"(smem_dst), "l"(src), "r"(src_bytes));
    } else if constexpr (PREFETCH_SIZE == PrefetchSize::B128) {
        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16, %2;\n"
                     :: "r"(smem_dst), "l"(src), "r"(src_bytes));
    } else {
        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16, %2;\n"
                     :: "r"(smem_dst), "l"(src), "r"(src_bytes));
    }
}
// Create fraction-based cache policy (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-createpolicy)
// Emits `createpolicy.fractional.L2::<primary>.L2::<secondary>.b64` and returns the 64-bit policy
// value, suitable for `.L2::cache_hint` operands (e.g. atomicadd_f32_with_policy_and_pred).
// PRIMARY_PRIORITY:   EVICT_FIRST / EVICT_NORMAL / EVICT_LAST / EVICT_UNCHANGED.
// SECONDARY_PRIORITY: EVICT_FIRST / EVICT_UNCHANGED.
// Any other combination fails to compile via static_assert.
// fraction: passed as the .f32 fraction operand of createpolicy (per the PTX docs, the fraction of
// accessed cache lines that receive the primary priority).
template<CacheHint PRIMARY_PRIORITY, CacheHint SECONDARY_PRIORITY>
CUTE_DEVICE
int64_t create_fraction_based_cache_policy(float fraction = 1.0f) {
int64_t result;
// EMIT(primary, secondary): emit the createpolicy instruction for one string-literal pair.
#define EMIT(PRIMARY_PRIORITY_STR, SECONDARY_PRIORITY_STR) \
asm volatile( \
"createpolicy.fractional.L2::" PRIMARY_PRIORITY_STR ".L2::" SECONDARY_PRIORITY_STR ".b64 %0, %1;\n" \
: "=l"(result) \
: "f"(fraction) \
);
// EMIT2(primary): select the secondary-priority string at compile time, then EMIT.
#define EMIT2(PRIMARY_PRIORITY_STR) \
{ \
if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_FIRST) { \
EMIT(PRIMARY_PRIORITY_STR, "evict_first") \
} else if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED) { \
EMIT(PRIMARY_PRIORITY_STR, "evict_unchanged") \
} else { \
static_assert(SECONDARY_PRIORITY == CacheHint::EVICT_FIRST || \
SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED, \
"Unsupported secondary cache hint for create_fraction_based_cache_policy."); \
} \
}
// Select the primary-priority string at compile time.
if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_FIRST) {
EMIT2("evict_first");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL) {
EMIT2("evict_normal");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_LAST) {
EMIT2("evict_last");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED) {
EMIT2("evict_unchanged");
} else {
static_assert(PRIMARY_PRIORITY == CacheHint::EVICT_FIRST ||
PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL ||
PRIMARY_PRIORITY == CacheHint::EVICT_LAST ||
PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED,
"Unsupported primary cache hint for create_fraction_based_cache_policy.");
}
// Keep the helper macros local to this function.
#undef EMIT
#undef EMIT2
return result;
}
// Compile-time L2 cache-policy constant for CACHE_HINT; no createpolicy instruction is emitted.
// Intended to be equivalent to create_fraction_based_cache_policy(1.0f), and identical to the
// values of cute::TMA::CacheHintSmXX. Only EVICT_FIRST, EVICT_NORMAL and EVICT_LAST are supported.
template<CacheHint CACHE_HINT>
CUTE_DEVICE
constexpr int64_t create_simple_cache_policy() {
    static_assert(CACHE_HINT == CacheHint::EVICT_FIRST ||
                  CACHE_HINT == CacheHint::EVICT_NORMAL ||
                  CACHE_HINT == CacheHint::EVICT_LAST,
                  "Unsupported cache hint for create_simple_cache_policy.");
    if constexpr (CACHE_HINT == CacheHint::EVICT_FIRST) {
        return 0x12F0000000000000;  // Result of createpolicy.fractional.L2::evict_first.b64
    } else if constexpr (CACHE_HINT == CacheHint::EVICT_LAST) {
        return 0x14F0000000000000;  // Result of createpolicy.fractional.L2::evict_last.b64
    } else {
        // EVICT_NORMAL: value copied from CuTe; exact encoding unverified (TODO Change to 0x16F0000000000000?)
        return 0x1000000000000000;
    }
}
// AtomicAdd (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
// Relaxed, GPU-scope `red.global.add.f32` of `data` into `*global_addr`, using `cache_policy` as
// the L2 cache hint (e.g. from create_simple_cache_policy / create_fraction_based_cache_policy).
// The reduction runs only when `pred` is non-zero. The predicate is built with `setp.ne ..., 0`, so any
// non-zero value enables it, as in C++; the previous `setp.eq ..., 1` silently skipped values like 2.
// NOTE(review): `global_addr` is passed unconverted into a `.global` instruction, so it is assumed
// to point into global memory.
CUTE_DEVICE
void atomicadd_f32_with_policy_and_pred(void* global_addr, const float &data, int64_t cache_policy, uint32_t pred = true) {
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.u32 p, %3, 0;\n\t"
        "@p red.relaxed.gpu.global.add.L2::cache_hint.f32 [%1], %0, %2; \n\t"
        "}"
        :
        : "f"(data),
          "l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
    );
}
// Get the id of the current SM
// About %smid (https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid): the PTX documentation says that %smid ranges from 0 to %nsmid-1, and that "The SM identifier numbering is not guaranteed to be contiguous, so %nsmid may be larger than the physical number of SMs in the device.". However, experiments show that, at least for sm90 and sm100f, %nsmid is the number of physical SMs - 1. For safety, we recommend checking the return value of get_sm_id manually or calling `get_sm_id_with_range_check()` (defined in `device/sm80/helpers.cuh`).
// Besides, PTX document also says that this number may change due to preemption, but currently this never happens according to [DATEN GELÖSCHT]
// Return the raw value of the %smid special register (no range check; see get_sm_id_with_range_check()).
CUTE_DEVICE
uint32_t get_sm_id() {
    uint32_t smid_value;
    asm volatile("mov.u32 %0, %%smid;\n" : "=r"(smid_value));
    return smid_value;
}
// trap (https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-trap)
// Aborts the kernel by executing the PTX `trap` instruction. `volatile` keeps the compiler from
// deleting or moving this output-less asm statement.
CUTE_DEVICE
void trap() {
asm volatile("trap;\n");
}
// 128-bit global load (LDG.128), optionally through the non-coherent path (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)
// We use a macro instead of a function here because the qualifiers are spliced into the PTX string as
// literals; a function would need a multi-level recursive dispatch over template parameters.
// NC_STR should be either "" or ".nc"
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B"
// L2 cache hint is not supported since it's only supported for LDG.256
// `result` must point to at least 16 writable bytes; they are written as two uint64_t halves.
// NOTE(review): `global_addr` goes unconverted into an `ld.global` instruction, so it is assumed to be a
// global-memory pointer (presumably 16-byte aligned for `.v2.u64`) -- confirm at call sites.
#define KU_LDG_128(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, "`result` must be a pointer"); \
uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \
asm volatile( \
"ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v2.u64 {%0, %1}, [%2];\n" \
: "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]) \
: "l"(global_addr) \
); \
}
}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm90/helpers.cuh
================================================
#pragma once
#include <cute/tensor.hpp>
#include "kerutils/device/common.h"
namespace kerutils {
// Issue the TMA copy described by `tma_copy` from `src` to `dst`, with completion tracked on `bar`.
// Both tensors are partitioned by the single TMA slice (index 0). The `0` passed to .with() is the
// multicast mask, and `cache_hint` selects the L2 eviction policy.
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    using BarValue = typename transac_bar_t::ValueType;
    auto tma_slice = tma_copy.get_slice(cute::_0{});
    auto bound_copy = tma_copy.with(reinterpret_cast<BarValue&>(bar), 0, cache_hint);
    cute::copy(bound_copy, tma_slice.partition_S(src), tma_slice.partition_D(dst));
}
// In the WGMMA layout of fragment A and accumulator C, each thread's elements lie in exactly two rows
// of its warp's 16-row slab: one row in the top 8 rows, and the row 8 below it.
// Maps local_row_idx (0 or 1) and the thread's index within the warpgroup to the actual row index.
// Layout reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
CUTE_DEVICE
int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
    const int warp_slab_base = (idx_in_warpgroup / 32) * 16;  // 16 rows per warp
    const int lane_row = idx_in_warpgroup % 32 / 4;          // 4 consecutive lanes share a row
    return warp_slab_base + local_row_idx * 8 + lane_row;
}
// WGMMA fragment A / accumulator C layout: map a thread's local element index to its column index.
// Every group of 4 local elements covers 8 columns; within a group, the 4 lanes that share a row own
// 2 adjacent columns each, and the low bit of local_elem_idx picks between those two.
// Layout reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
CUTE_DEVICE
int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
    const int group_base = (local_elem_idx / 4) * 8;
    const int lane_pair_offset = (idx_in_warpgroup % 4) * 2;
    const int within_pair = local_elem_idx & 1;
    return group_base + lane_pair_offset + within_pair;
}
// Issue a warpgroup MMA over every K block of tCrA/tCrB, accumulating into tCrC.
// zero_init: if true, the first K block overwrites tCrC (ScaleOut::Zero); later K blocks always accumulate.
// commit:    if true, the issued batch is committed with warpgroup_commit_batch(); no wait is done here.
// Operand A is treated as register-sourced (RS) when TiledMma::FrgTypeA is not a GMMA descriptor
// iterator; in that case tCrA is fenced as well.
template <bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE
void wgmma(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC, bool zero_init) {
using namespace cute;
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
warpgroup_arrive();
tiled_mma.accumulate_ = zero_init ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
if constexpr (commit) {
warpgroup_commit_batch();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
// Shared x shared WGMMA: partitions sA and sB for thread `idx_in_warpgroup` and accumulates into rC_frag.
// clear_accum: if true, the first K step overwrites rC_frag instead of adding to it.
// Unlike wgmma<commit=true>, no warpgroup_commit_batch() is issued here; committing and waiting are
// left to the caller. `tiled_mma` is taken by value, so the accumulate_ updates below stay local.
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE
void wgmma_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sA_frag = thr_mma.partition_fragment_A(sA);
Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(sA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(rC_frag);
warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(sA_frag); ++k) {
cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
// After the first K step, always accumulate.
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_fence_operand(rC_frag);
}
// Register x shared WGMMA: A comes from the register fragment rA_frag, and B from sB partitioned for
// thread `idx_in_warpgroup`. The result accumulates into rC_frag.
// clear_accum: if true, the first K step overwrites rC_frag instead of adding to it.
// Like wgmma_ss, no warpgroup_commit_batch() is issued here. rA_frag and tiled_mma are taken by value.
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE
void wgmma_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(rA_frag) == size<2>(sB_frag));
// The register operand A is fenced before and after the MMAs, like the accumulator.
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
warpgroup_fence_operand(rC_frag);
warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(rA_frag); ++k) {
cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_fence_operand(rC_frag);
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
}
}
================================================
FILE: csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh
================================================
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async)
// Asynchronously stores the 16 bytes of `data` (as two 64-bit halves) to the shared::cluster address
// derived from `dst_ptr`. Completion is reported as transaction bytes on `mbar`.
template<typename T>
CUTE_DEVICE
static void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) {
    static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async.");
    const uint32_t bar_u32 = cute::cast_smem_ptr_to_uint(&mbar);
    const uint32_t dst_u32 = cute::cast_smem_ptr_to_uint(dst_ptr);
    const long2 halves = *reinterpret_cast<const long2*>(&data);
    asm volatile (
        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
        :
        : "r"(dst_u32), "l"(halves.x), "l"(halves.y), "r"(bar_u32)
    );
}
// Address bit 24 (0x1000000 == 16777216); toggling it switches between a CTA and its peer CTA.
static constexpr int PEER_ADDR_MASK = 1 << 24;
// Given an address in the current CTA, return the corresponding address in the peer CTA
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
    const int64_t own_addr = (int64_t)(p);
    return (T*)(own_addr ^ PEER_ADDR_MASK);
}
// Given an address in the current CTA, return the corresponding address in the peer CTA (if the current
// CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0), by clearing the PEER_ADDR_MASK bit (bit 24).
// The mask is computed as ~(int64_t)PEER_ADDR_MASK. This keeps every upper bit set without relying on
// the implementation-defined conversion of the out-of-range literal 0xFEFFFFFF to `int`, followed by
// sign extension. The resulting mask value (0xFFFFFFFFFEFFFFFF) is unchanged.
template<typename T>
CUTE_DEVICE
T* get_cta0_addr(const T* p) {
    constexpr int64_t CTA0_ADDR_MASK = ~static_cast<int64_t>(PEER_ADDR_MASK);
    return (T*)((int64_t)(p) & CTA0_ADDR_MASK);
}
// TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk)
// Adds `store_bytes` bytes of f32 values from shared `src_ptr` into global `dst_ptr` as part of the
// current bulk async-group. The "memory" clobber prevents reordering of surrounding memory accesses.
CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
    const uint32_t smem_src = cute::cast_smem_ptr_to_uint(src_ptr);
    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
                 :
                 : "l"(dst_ptr), "r"(smem_src), "r"(store_bytes)
                 : "memory");
}
// Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
// The "memory" clobber also stops the compiler from moving memory accesses across the arrive.
CUTE_DEVICE
void barrier_cluster_arrive_release() {
asm volatile("barrier.cluster.arrive.release;" : : : "memory");
}
// Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
// No "memory" clobber (presumably intentional for the relaxed variant), so the compiler may move
// ordinary memory accesses across this arrive.
CUTE_DEVICE
void barrier_cluster_arrive_relaxed() {
asm volatile("barrier.cluster.arrive.relaxed;" : : :);
}
// Cluster barrier wait with .acquire modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
// The "memory" clobber also stops the compiler from moving memory accesses across the wait.
CUTE_DEVICE
void barrier_cluster_wait_acquire() {
asm volatile("barrier.cluster.wait.acquire;" : : : "memory");
}
// mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
// Performs one arrive on the CTA-local mbarrier `mbar`; the returned state is discarded (`_`).
CUTE_DEVICE
void mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) {
    const uint32_t bar_u32 = cute::cast_smem_ptr_to_uint(&mbar);
    asm volatile(
        "{\n\t"
        "mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\n\t"
        "}"
        :
        : "r"(bar_u32));
}
// AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
// Relaxed, GPU-scope vectorized `red.global.add.v4.f32` of the four components of `data` into the
// 16 bytes at `global_addr`, using `cache_policy` as the L2 cache hint.
// The reduction runs only when `pred` is non-zero. The predicate is built with `setp.ne ..., 0`, so any
// non-zero value enables it, as in C++; the previous `setp.eq ..., 1` silently skipped values like 2.
// NOTE(review): `global_addr` is passed unconverted into a `.global` instruction, so it is assumed to
// point into global memory (and presumably be 16-byte aligned for `.v4.f32`).
CUTE_DEVICE
void atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) {
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.u32 p, %6, 0;\n\t"
        "@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t"
        "}"
        :
        : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w),
          "l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
    );
}
// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)
// Copies `load_bytes` bytes from this CTA's shared `src_ptr` to the shared::cluster address derived
// from `dst_ptr`. Completion is reported as transaction bytes on `mbar`.
CUTE_DEVICE
void cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) {
    const uint32_t bar_u32 = cute::cast_smem_ptr_to_uint(&mbar);
    const uint32_t src_u32 = cute::cast_smem_ptr_to_uint(src_ptr);
    const uint32_t dst_u32 = cute::cast_smem_ptr_to_uint(dst_ptr);
    asm volatile(
        "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
        :
        : "r"(dst_u32), "r"(src_u32), "r"(load_bytes), "r"(bar_u32)
    );
}
}
================================================
FILE: csrc/kerutils/include/kerutils/host/host.h
================================================
#pragma once
#include <exception>
#include <string>
#include <sstream>
#include <vector>
#include <cuda_runtime_api.h>
#include <cuda.h>
#include <cutlass/cuda_host_adapter.hpp>
#include "kerutils/common/common.h"
namespace kerutils {
// Exception thrown by the KU_* host-side check macros.
// what() returns "<name> error (<file>:<line>): " followed by every extra argument, streamed in order.
class KUException final : public std::exception {
    std::string message = {};

public:
    template<typename... Args>
    explicit KUException(const char *name, const char* file, const int line, Args&&... args) {
        std::ostringstream builder;
        builder << name << " error (" << file << ":" << line << "): ";
        // Comma fold: stream each extra argument in order (does nothing for an empty pack).
        ((builder << args), ...);
        message = builder.str();
    }

    const char *what() const noexcept override {
        return message.c_str();
    }
};
// Throw a kerutils::KUException tagged with `name` and the call site's __FILE__/__LINE__; the
// remaining arguments are streamed into the exception message in order.
#define THROW_KU_EXCEPTION(name, ...) \
throw kerutils::KUException(name, __FILE__, __LINE__, __VA_ARGS__)
// Evaluate a CUDA runtime call once. On any status other than cudaSuccess, print the error string
// and call site to stderr, then throw a KUException tagged "CUDA".
#define KU_CUDA_CHECK(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
THROW_KU_EXCEPTION("CUDA", "CUDA error: ", cudaGetErrorString(status_)); \
} \
} while(0)
// Evaluate a CUTLASS call once. On any status other than kSuccess, print the numeric status and call
// site to stderr, then throw a KUException tagged "CUTLASS".
#define KU_CUTLASS_CHECK(call) \
do { \
cutlass::Status status_ = call; \
if (status_ != cutlass::Status::kSuccess) { \
fprintf(stderr, "CUTLASS error (%s:%d): %d\n", __FILE__, __LINE__, static_cast<int>(status_)); \
THROW_KU_EXCEPTION("CUTLASS", "CUTLASS error: ", static_cast<int>(status_)); \
} \
} while(0)
// This `KU_ASSERT` is triggered no matter if the code is compiled with `-DNDEBUG` or not.
// On failure it prints "Assertion `cond` failed (file:line)" to stderr. If an optional printf-style
// message is given (its first argument must be a string literal), ": <message>" is appended; the
// previous ", " prefix produced "...failed (file:line): , message". It then throws a KUException
// via THROW_KU_EXCEPTION.
#define KU_ASSERT(cond, ...) \
do { \
    if (not (cond)) { \
        fprintf(stderr, "Assertion `%s` failed (%s:%d)", #cond, __FILE__, __LINE__); \
        if constexpr (sizeof(#__VA_ARGS__) > 1) { \
            fprintf(stderr, ": " __VA_ARGS__); \
        } \
        fprintf(stderr, "\n"); \
        THROW_KU_EXCEPTION("Assertion", "Assertion `", #cond, "` failed."); \
    } \
} while(0)
// Check the last CUDA runtime error, e.g. right after a kernel launch; cudaGetLastError also resets it.
#define KU_CHECK_KERNEL_LAUNCH() KU_CUDA_CHECK(cudaGetLastError())
// Ceiling division (a + b - 1) / b, intended for non-negative `a` and positive `b`.
// The intermediate is kept in the (possibly promoted) arithmetic type, then converted back to T on return.
template<typename T>
inline __host__ __device__ constexpr T ceil_div(const T &a, const T &b) {
    const auto biased = a + b - 1;
    return biased / b;
}
// Round `a` up to the next multiple of `b`: (a + b - 1) / b * b, for non-negative `a` and positive `b`.
template<typename T>
inline __host__ __device__ constexpr T ceil(const T &a, const T &b) {
    const auto num_chunks = (a + b - 1) / b;
    return num_chunks * b;
}
// Thin wrapper around cuTensorMapEncodeTiled (called through CUTLASS_CUDA_DRIVER_WRAPPER_CALL).
//   size             global extent of each dimension; dim = size.size() must be >= 1
//   strides          byte strides for dimensions 1..dim-1 (exactly dim-1 entries) -- PAY ATTENTION: in BYTES
//   box_size         box extent per dimension (dim entries)
//   element_strides_ per-dimension element strides; an empty vector means "all 1"
// Returns the encoded CUtensorMap. If the driver call fails, every argument is dumped to stderr and
// KU_ASSERT(false) throws kerutils::KUException.
static inline CUtensorMap make_tensor_map(
    const std::vector<uint64_t> &size,
    const std::vector<uint64_t> &strides, // PAY ATTENTION: In BYTES
    const std::vector<uint32_t> &box_size,
    void* global_ptr,
    CUtensorMapDataType data_type,
    CUtensorMapSwizzle swizzle_mode,
    CUtensorMapL2promotion l2_promotion,
    CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
    CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    const std::vector<uint32_t> &element_strides_ = {}
) {
    int dim = size.size();
    KU_ASSERT(dim >= 1);
    // Default every element stride to 1 when the caller does not provide them.
    std::vector<uint32_t> element_strides = element_strides_.empty()
        ? std::vector<uint32_t>(static_cast<size_t>(dim), 1u)
        : element_strides_;
    KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim);
    CUtensorMap result;
    CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
        &result,
        data_type,
        dim,
        global_ptr,
        size.data(),
        strides.data(),
        box_size.data(),
        element_strides.data(),
        interleave_mode,
        swizzle_mode,
        l2_promotion,
        oob_fill
    );
    if (ret_code != CUresult::CUDA_SUCCESS) {
        // Everything goes to stderr. KU_ASSERT below throws, and if that exception is not caught the
        // process aborts without flushing a buffered stdout, which would silently drop this dump.
        auto print_vector = [&](const auto &t, const char* fmt, const char end='\n') {
            for (auto elem : t) {
                fprintf(stderr, fmt, elem);
            }
            fprintf(stderr, "%c", end);
        };
        fprintf(stderr, "Failed to create tensormap\n");
        fprintf(stderr, "Dim: %d\n", dim);
        fprintf(stderr, "size: "); print_vector(size, "%lu ");
        fprintf(stderr, "strides: "); print_vector(strides, "%lu ");
        fprintf(stderr, "box_size: "); print_vector(box_size, "%u ");
        fprintf(stderr, "element_strides: "); print_vector(element_strides, "%u ");
        fprintf(stderr, "global ptr: %p\n", global_ptr);
        fprintf(stderr, "data_type: %d\n", (int)data_type);
        fprintf(stderr, "swizzle_mode: %d\n", (int)swizzle_mode);
        fprintf(stderr, "l2_promotion: %d\n", (int)l2_promotion);
        fprintf(stderr, "interleave_mode: %d\n", (int)interleave_mode);
        fprintf(stderr, "oob_fill: %d\n", (int)oob_fill);
        KU_ASSERT(false);
    }
    return result;
}
// Convert strides given in elements into byte strides: each entry is cast to uint64_t and multiplied
// by `elem_size`. The output has the same length and order as the input.
template<typename T>
static inline std::vector<uint64_t> make_stride_helper(const std::vector<T> &strides_in_elems, size_t elem_size) {
    std::vector<uint64_t> byte_strides;
    byte_strides.reserve(strides_in_elems.size());
    for (size_t i = 0; i < strides_in_elems.size(); ++i) {
        byte_strides.push_back(static_cast<uint64_t>(strides_in_elems[i]) * elem_size);
    }
    return byte_strides;
}
}
================================================
FILE: csrc/kerutils/include/kerutils/kerutils.cuh
================================================
#pragma once
#include "host/host.h"
#include "device/device.cuh"
================================================
FILE: csrc/kerutils/include/kerutils/supplemental/torch_tensors.h
================================================
#pragma once
#include <functional>
#include <torch/python.h>
#include "kerutils/common/common.h"
namespace kerutils {
// Apply `check_fn` to a tensor or an optional tensor.
// - T is at::Tensor: return check_fn(tensor).
// - otherwise (optional-like, with has_value()/value()): an empty optional passes (returns true);
//   a present value is checked with check_fn.
template<typename T>
static inline bool _check_optional_tensor(const T& tensor_or_opt, const std::function<bool(const at::Tensor&)>& check_fn) {
    if constexpr (!std::is_same<T, at::Tensor>::value) {
        return !tensor_or_opt.has_value() || check_fn(tensor_or_opt.value());
    } else {
        return check_fn(tensor_or_opt);
    }
}
// Return tensor.data_ptr() cast to PtrT* when the tensor has backing storage, or nullptr otherwise.
template<typename PtrT>
static inline PtrT* get_tensor_ptr(const at::Tensor& tensor) {
    return tensor.has_storage() ? (PtrT*)tensor.data_ptr() : nullptr;
}
// Pointer accessor that accepts either a tensor or an optional tensor.
// Returns get_tensor_ptr<PtrT>(tensor) for a plain tensor or a present optional, and nullptr for an
// empty optional (or for a tensor without storage, via get_tensor_ptr).
template<typename PtrT, typename T>
static inline PtrT* get_optional_tensor_ptr(const T& tensor_or_opt) {
    if constexpr (!std::is_same<T, at::Tensor>::value) {
        if (!tensor_or_opt.has_value()) {
            return nullptr;
        }
        return get_tensor_ptr<PtrT>(*tensor_or_opt);
    } else {
        return get_tensor_ptr<PtrT>(tensor_or_opt);
    }
}
}
// NOTE(review): these macros call `ku::_check_optional_tensor`; `ku` is presumably a namespace alias for
// `kerutils` declared elsewhere -- confirm.
// Check whether the given tensor (or optional<tensor>) is on cuda
#define KU_CHECK_DEVICE(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_cuda(); }), #tensor " must be on CUDA")
// Check whether the given tensor (or optional<tensor>) has the given number of dimensions
#define KU_CHECK_NDIM(tensor, ndim) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dim() == (ndim); }), #tensor " must have " #ndim " dimensions")
// Check whether the given tensor (or optional<tensor>) has the given shape
#define KU_CHECK_SHAPE(tensor, ...) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.sizes() == torch::IntArrayRef({__VA_ARGS__}); }), #tensor " must have shape (" #__VA_ARGS__ ")")
// Check whether the given tensor (or optional<tensor>) is contiguous
#define KU_CHECK_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_contiguous(); }), #tensor " must be contiguous")
// Check whether the last dimension of the given tensor (or optional<tensor>) is contiguous (stride 1, or size 1)
#define KU_CHECK_LAST_DIM_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.size(-1) == 1 || t.stride(-1) == 1; }), #tensor " must have contiguous last dimension")
// Check whether the given tensor (or optional<tensor>) has the specified dtype.
// The lambda captures by reference (like KU_CHECK_NDIM/KU_CHECK_SHAPE) so that `target_dtype` may also be
// a local variable, not only a global constant such as torch::kBFloat16.
#define KU_CHECK_DTYPE(tensor, target_dtype) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dtype() == (target_dtype); }), #tensor " must have dtype " #target_dtype)
================================================
FILE: csrc/params.h
================================================
#pragma once
#include "cutlass/bfloat16.h"
// Model variant selector (stored in SparseAttnDecodeParams::model_type).
// NOTE(review): presumably corresponds to the `v32.cu` / `model1.cu` decode instantiations -- confirm.
enum class ModelType {
V32,
MODEL1
};
// Element type of the `tile_scheduler_metadata_ptr` arrays in the decode param structs below: one
// per-SM-part work range (requests/blocks/splits) -- NOTE(review): semantics inferred from field names.
// Eight ints with __align__(32), so sizeof == 32 bytes; `_pad` fills the last int slot.
struct __align__(4*8) DecodingSchedMeta {
int begin_req_idx, end_req_idx; // Both inclusive
int begin_block_idx, end_block_idx; // Inclusive, exclusive
int begin_split_idx;
int is_first_req_splitted, is_last_req_splitted;
int _pad[1];
};
// Size in bytes of one DecodingSchedMeta entry.
static constexpr int DecodingSchedMetaSize = sizeof(DecodingSchedMeta);
// Parameters for the dense (non-sparse) attention decode path, with a paged KV cache (block_table)
// and split-KV outputs. NOTE(review): index_t strides are presumably in elements, not bytes -- confirm.
struct DenseAttnDecodeParams {
using index_t = int64_t;
int b; // batch size
int s_q;
int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q
int d, d_v; // K/V dimension
int h_q, h_k; // The number of Q/K heads
int num_blocks; // Number of blocks in total
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal;
float scale_softmax, scale_softmax_log2;
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ o_ptr;
float *__restrict__ softmax_lse_ptr;
index_t q_batch_stride;
index_t k_batch_stride;
index_t o_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t o_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t o_head_stride;
// Paged KV cache: per-batch block table
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int *__restrict__ seqlens_k_ptr;
// Split-KV scheduling metadata
DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
int total_num_splits;
// Split-KV partial results (presumably combined afterwards, cf. CombineParams)
float *__restrict__ softmax_lseaccum_ptr;
float *__restrict__ oaccum_ptr;
cudaStream_t stream;
};
// Parameters for sparse-attention decoding: each query attends to `topk` rows of a paged KV cache
// selected by `indices`, plus an optional second ("extra") KV cache with its own indices. Strides are
// stored as int. The split-KV section holds the partial outputs and the schedule.
struct SparseAttnDecodeParams {
int b, s_q;
int h_q, h_kv;
int d_qk, d_v;
float sm_scale, sm_scale_div_log2;
int num_blocks, page_block_size, topk;
ModelType model_type;
cutlass::bfloat16_t* __restrict__ q; // [b, s_q, h_q, d_qk]
cutlass::bfloat16_t* __restrict__ kv; // [num_blocks, page_block_size, d_qk]
int* __restrict__ indices; // [b, s_q, topk]
int* __restrict__ topk_length; // [b], may be nullptr
float* __restrict__ attn_sink; // [h_q], may be nullptr
float* __restrict__ lse; // [b, s_q, h_q]
cutlass::bfloat16_t* __restrict__ out; // [b, s_q, h_q, d_v]
// Optional extra KV cache and its indices
int extra_num_blocks, extra_page_block_size, extra_topk;
cutlass::bfloat16_t* __restrict__ extra_kv; // [extra_num_blocks, extra_page_block_size, d_qk]
int* __restrict__ extra_indices; // [b, s_q, extra_topk]
int* __restrict__ extra_topk_length; // [b], may be nullptr
int stride_q_b, stride_q_s_q, stride_q_h_q;
int stride_kv_block, stride_kv_row;
int stride_indices_b, stride_indices_s_q;
int stride_lse_b, stride_lse_s_q;
int stride_o_b, stride_o_s_q, stride_o_h_q;
int stride_extra_kv_block, stride_extra_kv_row;
int stride_extra_indices_b, stride_extra_indices_s_q;
cudaStream_t stream;
// SplitKV-related parameters
float* __restrict__ lse_accum; // [num_splits, s_q, h_q]
float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v]
int stride_lse_accum_split, stride_lse_accum_s_q;
int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;
DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous
int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous
int num_sm_parts;
};
// Parameters for merging split-KV partial results (lse_accum / o_accum) into the final lse / out,
// guided by the decode schedule (tile_scheduler_metadata_ptr, num_splits_ptr).
// NOTE(review): the role is inferred from field names -- confirm against the combine kernel.
struct CombineParams {
int b, s_q, h_q, d_v;
float* __restrict__ lse; // [b, s_q, h_q]
void* __restrict__ out; // [b, s_q, h_q, d_v]
int stride_lse_b, stride_lse_s_q;
int stride_o_b, stride_o_s_q, stride_o_h_q;
float* __restrict__ lse_accum; // [num_splits, s_q, h_q]
float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v]
int stride_lse_accum_split, stride_lse_accum_s_q;
int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;
DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous
int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous
int num_sm_parts;
float* attn_sink; // [h_q], may be nullptr
cudaStream_t stream;
};
// Inputs for building the decode schedule, plus the output buffers it fills
// (tile_scheduler_metadata_ptr: num_sm_parts DecodingSchedMeta entries; num_splits_ptr).
// NOTE(review): the output direction is inferred from field names -- confirm against the scheduler kernel.
struct GetDecodeSchedMetaParams {
int b; // batch size
int s_q;
int block_size_n;
int fixed_overhead_num_blocks;
int topk, extra_topk; // -1 if sparse attention (or extra topk) is disabled
int *__restrict__ topk_length, *__restrict__ extra_topk_length;
int *__restrict__ seqlens_k_ptr; // Only necessary for dense attention
DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
int *__restrict__ num_splits_ptr;
int num_sm_parts;
cudaStream_t stream;
};
// Arguments for the sparse attention forward kernels in prefill mode.
// SparseFwdArgT selects this struct for SparseAttnFwdMode::Prefill.
// NOTE: field order defines the struct layout; keep it unchanged.
struct SparseAttnFwdParams {
    // Problem sizes
    int s_q;
    int s_kv;
    int h_q;
    int h_kv;
    int d_qk;
    int d_v;
    int topk;
    float sm_scale;
    float sm_scale_div_log2;  // presumably sm_scale / ln(2), for exp2-based softmax

    // Input tensors
    cutlass::bfloat16_t *__restrict__ q;   // [s_q, h_q, d_qk]
    cutlass::bfloat16_t *__restrict__ kv;  // [s_kv, h_kv, d_qk]
    int *__restrict__ indices;             // [s_q, h_kv, topk]
    float *__restrict__ attn_sink;         // [h_q], may be nullptr
    int *__restrict__ topk_length;         // [s_q], may be nullptr

    // Strides
    int stride_q_s_q;
    int stride_q_h_q;
    int stride_kv_s_kv;
    int stride_kv_h_kv;
    int stride_indices_s_q;
    int stride_indices_h_kv;

    // Output tensors
    cutlass::bfloat16_t *__restrict__ out;  // [s_q, h_q, d_v]
    float *__restrict__ max_logits;         // [s_q, h_q]
    float *__restrict__ lse;                // [s_q, h_q]

    int num_sm;
    cudaStream_t stream;
};
// Some kernels implement both prefill and decode in a single kernel body, selected through
// different template instantiations. This enum tells the two modes apart.
enum class SparseAttnFwdMode {
    Prefill = 0,            // Normal prefill mode
    DecodeWithSplitKV = 1,  // Triggers decoding mode for kernels that support both prefill and decode
};
// True exactly when FWD_MODE selects the split-KV decoding path.
template<SparseAttnFwdMode FWD_MODE>
inline constexpr bool is_decode_v = (FWD_MODE == SparseAttnFwdMode::DecodeWithSplitKV);
template<SparseAttnFwdMode FWD_MODE>
using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;
================================================
FILE: csrc/sm100/decode/head128/README.md
================================================
Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512). Otherwise, they are simulated using two invocations (2x) of the head64 kernel.
================================================
FILE: csrc/sm100/decode/head64/config.h
================================================
#pragma once
#include "kernel.h"
#include <cuda_fp8.h>
#include <cutlass/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm100::decode::head64 {
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using e8m0 = __nv_fp8_e8m0;
using e4m3 = cutlass::float_e4m3_t;
using namespace cute;
enum NamedBarriers : uint32_t {
main_loop_sync = 0,
wg0_sync = 1,
wg0_warp02_sync = 2,
wg0_warp13_sync = 3,
everyone_sync = 4
};
template<ModelType MODEL_TYPE>
struct KernelTemplate {
static constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int D_K = D_Q;
static constexpr int D_V = 512;
static constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 512 : 448;
static constexpr int D_ROPE = 64;
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true;
static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8; // Padding is included
static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE; // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens 2) large enough to cover the entire KV cache. Since TMA copy's coordinate can only be 32bit signed integers, this number must >= 128, perferrably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks.
static_assert(D_NOPE + D_ROPE == D_Q);
static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V));
static constexpr int B_H = 64;
static constexpr int B_TOPK = 64;
static constexpr int NUM_BUFS = 2;
static constexpr int NUM_INDEX_BUFS = 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales
static constexpr int NUM_THREADS = 128*3; // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant
static constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN
static constexpr int D_Q_SW128 = 512;
static constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0;
static_assert(D_Q_SW128 + D_Q_SW64 == D_Q);
static constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 64 : 128; // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes
template<
typename Shape_Q_SW128, typename TMA_Q_SW128,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q_SW128 shape_Q_SW128; TMA_Q_SW128 tma_Q_SW128;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_q_sw64; // Invalid if D_Q_SW64 == 0
CUtensorMap tensor_map_kv_nope;
CUtensorMap tensor_map_kv_rope;
CUtensorMap tensor_map_extra_kv_nope;
CUtensorMap tensor_map_extra_kv_rope;
};
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: output
// 256 ~ 256 + 64*D_Q/256: Q
// 400 ~ 464: P
static constexpr int O = 0;
static constexpr int Q = 256;
static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128;
static constexpr int P = 400;
};
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<NUM_TILES*64>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutQ_SW128 = SmemLayoutQTiles<D_Q_SW128/64>;
using SmemLayoutOBuf = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_V>>{}
));
using SmemLayoutOBuf_TMA = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64>>{}
)); // A TMA tile
static_assert(D_V == 512);
using SmemLayoutOAccumBuf = Layout<
Shape<Int<B_H>, Int<D_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
using SmemLayoutS = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW128 = decltype(composition(
SmemLayoutKTiles_SW128<NUM_TILES>{},
Layout<
Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW64 = decltype(composition(
SmemLayoutKTiles_SW64<NUM_TILES>{},
Layout<
Shape<Int<32*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
struct SharedMemoryPlan {
union {
struct {
array_aligned<bf16, cosize_v<SmemLayoutQ_SW128>> q;
bf16 q_sw64[B_H*D_Q_SW64]; // NOTE D_Q_SW64 may be 0 but array_aligned<bf16, 0> will have a size of 16, so we use array here. The former tensor (`q`) promises its alignment.
union {
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;
} o;
} qo;
struct {
struct {
array_aligned<bf16, B_H*D_NOPE> nope; // NoPE part, dequantized
array_aligned<bf16, B_H*D_ROPE> rope; // RoPE part, dequantized. SW64 in v32 mode, SW128 in MODEL1 mode
} dequant[NUM_BUFS];
static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q)); // So that Q does not covers raw_nope
array_aligned<e4m3, B_H*D_NOPE> raw_nope[NUM_BUFS]; // Raw (quantized) NoPE part
} kv;
} u;
union {
float4 p_exchange_buf[4][16 * B_TOPK / 4];
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
} s_p;
CUTE_ALIGNAS(16) float rowwise_max_buf[128];
char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8];
int tma_coord[NUM_INDEX_BUFS][B_TOPK];
e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN];
array_aligned<uint32_t, 1> tmem_start_addr;
transac_bar_t bar_last_store_done;
transac_bar_t bar_q_tma, bar_q_utccp;
transac_bar_t bar_rope_ready[NUM_BUFS];
transac_bar_t bar_nope_ready[NUM_BUFS];
transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS];
transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS];
transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS];
};
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}
)); // *2 for dual gemm
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}
));
template<typename TmaParam>
static __device__ void
flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams ¶ms, const TmaParam &tma_params);
static void run(const SparseAttnDecodeParams ¶ms);
};
}
================================================
FILE: csrc/sm100/decode/head64/instantiations/model1.cu
================================================
#include "../kernel.cuh"
namespace sm100::decode::head64 {
// Explicit instantiation of the MODEL1 decoding entry point, whose template definition
// comes from the included kernel.cuh (declaration presumably in kernel.h).
template
void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1>(const SparseAttnDecodeParams &params);
}  // namespace sm100::decode::head64
================================================
FILE: csrc/sm100/decode/head64/instantiations/v32.cu
================================================
#include "../kernel.cuh"
namespace sm100::decode::head64 {
// Explicit instantiation of the V32 decoding entry point, whose template definition
// comes from the included kernel.cuh (declaration presumably in kernel.h).
template
void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32>(const SparseAttnDecodeParams &params);
}  // namespace sm100::decode::head64
================================================
FILE: csrc/sm100/decode/head64/kernel.cuh
================================================
#include "kernel.h"
#include <math_constants.h>
#include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/tensor.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include "kerutils/kerutils.cuh"
#include "utils.h"
#include "sm100/helpers.h"
#include "config.h"
namespace sm100::decode::head64 {
template<ModelType MODEL_TYPE>
template<typename TmaParam>
__device__ void
KernelTemplate<MODEL_TYPE>
::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams ¶ms, const TmaParam &tma_params) {
#if defined(KERUTILS_ENABLE_SM100A)
const int s_q_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q_SW128.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_q_sw64);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);
}
if (warp_idx == 0) {
if (elect_one_sync()) {
plan.bar_last_store_done.init(128);
plan.bar_q_tma.init(1);
plan.bar_q_utccp.init(1);
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_rope_ready[i].init(1);
plan.bar_nope_ready[i].init(128);
plan.bar_raw_ready[i].init(1);
plan.bar_raw_free[i].init(128);
plan.bar_qk_done[i].init(1);
plan.bar_so_ready[i].init(128);
plan.bar_sv_done[i].init(1);
}
for (int i = 0; i < NUM_INDEX_BUFS; ++i) {
plan.bar_valid_coord_scale_ready[i].init(32);
plan.bar_valid_coord_scale_free[i].init(128+128+1+1);
}
cutlass::arch::fence_barrier_init();
}
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());
KU_TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator1Sm().release_allocation_lock();
}
__syncthreads();
struct MainLoopArgs {
int batch_idx, start_block_idx, end_block_idx;
bool is_no_split; int n_split_idx;
bool bar_phase_batch_rel; // Bar phase of barriers that are used once per batch
int topk_length, extra_topk_length, num_orig_kv_blocks;
bool is_last_batch;
};
auto run_main_loop = [&](auto f) {
// NOTE Putting the following code outside the warpgroup specialization switch results in register spilling.
// [[maybe_unused]] int begin_req_idx, end_req_idx, sched_begin_block_idx, sched_end_block_idx, begin_n_split_idx, is_first_req_splitted, is_last_req_splitted;
DecodingSchedMeta sched_meta;
KU_LDG_256(
params.tile_scheduler_metadata_ptr + partition_idx,
&sched_meta,
".nc",
"no_allocate",
"evict_normal",
"256B"
);
if (sched_meta.begin_req_idx >= params.b) {
return;
}
bool bar_phase_batch_rel = 0;
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx, bar_phase_batch_rel ^= 1) {
int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0
int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK;
bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false);
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx);
MainLoopArgs args = {
batch_idx, start_block_idx, end_block_idx,
!is_split, n_split_idx,
bar_phase_batch_rel,
topk_length, extra_topk_length,
orig_topk_padded / B_TOPK,
batch_idx == sched_meta.end_req_idx
};
f(args);
NamedBarrier(NUM_THREADS, NamedBarriers::everyone_sync).arrive_and_wait_unaligned();
}
};
struct RingState {
int buf_idx = 0;
bool bar_phase = 0;
int index_buf_idx = 0;
bool index_bar_phase = 0;
CUTE_DEVICE void update() {
bar_phase ^= (buf_idx == NUM_BUFS-1);
buf_idx = (buf_idx+1) % NUM_BUFS;
index_bar_phase ^= (index_buf_idx == NUM_INDEX_BUFS-1);
index_buf_idx = (index_buf_idx+1) % NUM_INDEX_BUFS;
}
};
RingState rs;
if (warpgroup_idx == 0) {
// Scale & Exp warpgroup
// The same technique (and highly similar code) as the sm100 sparse prefill head64 kernel
cutlass::arch::warpgroup_reg_alloc<224>();
constexpr int B_EPI = 64; // Must be equal to the size of the swizzle atom
Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_buf.data()), SmemLayoutOBuf{});
bf16* sO_bases[B_EPI/8]; // 64 is the size of the swizzle atom (in number of elements) while 8 is the width of each write
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i)
sO_bases[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*128 + i*8);
const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};
bf16* sS_base = plan.s_p.s.data() + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2);
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg((float*)params.attn_sink + (idx_in_warpgroup%64)) * CUDART_L2E_F;
run_main_loop([&](const MainLoopArgs &args) {
cute::tma_store_wait<0>();
plan.bar_last_store_done.arrive();
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // Make sure all intermediate buffers (including p_exchange_buf, rowwise max_buf) are free
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); // Put the barrier wait here for more code reordering space
plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
// Load P
float p[B_TOPK/2], p_peer[B_TOPK/2];
if (warp_idx < 2) {
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p);
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p_peer);
} else {
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p_peer);
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p);
}
cutlass::arch::fence_view_async_tmem_load();
ku::tcgen05_before_thread_sync();
// Reduce within shared mem
{
// Store
// Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (row 0 ~ 31) part
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/4; ++i)
plan.s_p.p_exchange_buf[warp_idx^2][i*32 + lane_idx] = *(float4*)(p_peer + i*4);
NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1)); // Synchronize between warp 0 and warp 2, as well as warp 1 - warp 3
// Load
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/4; ++i) {
float2 t[2];
*(float4*)t = plan.s_p.p_exchange_buf[warp_idx][i*32 + lane_idx];
float2* cur_p = (float2*)(p + i*4);
cur_p[0] = ku::float2_add(cur_p[0], t[0]);
cur_p[1] = ku::float2_add(cur_p[1], t[1]);
}
}
// Since dual gemm is utilized, the layout of P in register now look like:
//
// 32 32
// +-------+-------+
// | | |
// 32 | Warp0 | Warp2 |
// | | |
// +-------+-------+
// | | |
// 32 | Warp1 | Warp3 |
// | | |
// +-------+-------+
// Mask
uint32_t valid_mask = *((uint32_t*)plan.is_token_valid[rs.index_buf_idx] + (idx_in_warpgroup>=64?1:0));
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2; i += 1) {
if (!(valid_mask>>i&1))
p[i] = -CUDART_INF_F;
}
// Get rowwise max of Pi
float cur_pi_max = -CUDART_INF_F;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2); i += 1) {
cur_pi_max = max(cur_pi_max, p[i]);
}
cur_pi_max *= params.sm_scale_div_log2;
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // This also separates "reading p_exchange_buf" and "writing S"
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// By this point:
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. among threads 0~31+64~95; and is identical among threads 32~63+96~127)
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
// Calculate S
__nv_bfloat162 s[(B_TOPK/2)/2];
float2 neg_new_max = float2 {-new_max, -new_max};
float2 cur_sum = float2 {0.0f, 0.0f};
CUTE_UNROLL
gitextract_clsc5nbn/
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── benchmark/
│ ├── bench_flash_mla.py
│ └── visualize.py
├── csrc/
│ ├── api/
│ │ ├── api.cpp
│ │ ├── common.h
│ │ ├── dense_decode.h
│ │ ├── dense_fwd.h
│ │ ├── sparse_decode.h
│ │ └── sparse_fwd.h
│ ├── defines.h
│ ├── kerutils/
│ │ └── include/
│ │ └── kerutils/
│ │ ├── common/
│ │ │ └── common.h
│ │ ├── device/
│ │ │ ├── common.h
│ │ │ ├── device.cuh
│ │ │ ├── sm100/
│ │ │ │ ├── gemm.cuh
│ │ │ │ ├── helpers.cuh
│ │ │ │ ├── intrinsics.cuh
│ │ │ │ └── tma_cta_group2_nosplit.cuh
│ │ │ ├── sm80/
│ │ │ │ ├── helpers.cuh
│ │ │ │ └── intrinsics.cuh
│ │ │ └── sm90/
│ │ │ ├── helpers.cuh
│ │ │ └── intrinsics.cuh
│ │ ├── host/
│ │ │ └── host.h
│ │ ├── kerutils.cuh
│ │ └── supplemental/
│ │ └── torch_tensors.h
│ ├── params.h
│ ├── sm100/
│ │ ├── decode/
│ │ │ ├── head128/
│ │ │ │ └── README.md
│ │ │ └── head64/
│ │ │ ├── config.h
│ │ │ ├── instantiations/
│ │ │ │ ├── model1.cu
│ │ │ │ └── v32.cu
│ │ │ ├── kernel.cuh
│ │ │ └── kernel.h
│ │ ├── helpers.h
│ │ └── prefill/
│ │ ├── dense/
│ │ │ ├── collective/
│ │ │ │ ├── fmha_common.hpp
│ │ │ │ ├── fmha_fusion.hpp
│ │ │ │ ├── sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
│ │ │ │ ├── sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
│ │ │ │ ├── sm100_fmha_load_tma_warpspecialized.hpp
│ │ │ │ ├── sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
│ │ │ │ └── sm100_fmha_mla_load_tma_warpspecialized.hpp
│ │ │ ├── common/
│ │ │ │ ├── gather_tensor.hpp
│ │ │ │ ├── helper.h
│ │ │ │ ├── mask.cuh
│ │ │ │ ├── pipeline_mla.hpp
│ │ │ │ ├── pow_2.hpp
│ │ │ │ └── utils.hpp
│ │ │ ├── device/
│ │ │ │ ├── fmha.hpp
│ │ │ │ └── fmha_device_bwd.hpp
│ │ │ ├── fmha_cutlass_bwd_sm100.cu
│ │ │ ├── fmha_cutlass_bwd_sm100.cuh
│ │ │ ├── fmha_cutlass_fwd_sm100.cu
│ │ │ ├── fmha_cutlass_fwd_sm100.cuh
│ │ │ ├── interface.h
│ │ │ └── kernel/
│ │ │ ├── fmha_causal_tile_scheduler.hpp
│ │ │ ├── fmha_kernel_bwd_convert.hpp
│ │ │ ├── fmha_kernel_bwd_sum_OdO.hpp
│ │ │ ├── fmha_options.hpp
│ │ │ ├── fmha_tile_scheduler.hpp
│ │ │ ├── sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
│ │ │ ├── sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp
│ │ │ └── sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
│ │ └── sparse/
│ │ ├── common_subroutine.h
│ │ ├── fwd/
│ │ │ ├── head128/
│ │ │ │ ├── config.h
│ │ │ │ ├── instantiations/
│ │ │ │ │ ├── phase1_k512.cu
│ │ │ │ │ └── phase1_k576.cu
│ │ │ │ ├── phase1.cuh
│ │ │ │ └── phase1.h
│ │ │ └── head64/
│ │ │ ├── config.h
│ │ │ ├── instantiations/
│ │ │ │ ├── phase1_k512.cu
│ │ │ │ └── phase1_k576.cu
│ │ │ ├── phase1.cuh
│ │ │ └── phase1.h
│ │ └── fwd_for_small_topk/
│ │ └── head128/
│ │ ├── config.h
│ │ ├── instantiations/
│ │ │ ├── phase1_decode_k512.cu
│ │ │ └── phase1_prefill_k512.cu
│ │ ├── phase1.cuh
│ │ └── phase1.h
│ ├── sm90/
│ │ ├── decode/
│ │ │ ├── dense/
│ │ │ │ ├── config.h
│ │ │ │ ├── instantiations/
│ │ │ │ │ ├── bf16.cu
│ │ │ │ │ └── fp16.cu
│ │ │ │ ├── splitkv_mla.cuh
│ │ │ │ ├── splitkv_mla.h
│ │ │ │ └── traits.h
│ │ │ └── sparse_fp8/
│ │ │ ├── components/
│ │ │ │ ├── config.h
│ │ │ │ ├── dequant.h
│ │ │ │ └── helpers.h
│ │ │ ├── config.h
│ │ │ ├── instantiations/
│ │ │ │ ├── model1_persistent_h128.cu
│ │ │ │ ├── model1_persistent_h64.cu
│ │ │ │ ├── v32_persistent_h128.cu
│ │ │ │ └── v32_persistent_h64.cu
│ │ │ ├── splitkv_mla.cuh
│ │ │ └── splitkv_mla.h
│ │ ├── helpers.h
│ │ └── prefill/
│ │ └── sparse/
│ │ ├── config.h
│ │ ├── fwd.cu
│ │ ├── fwd.h
│ │ ├── instantiations/
│ │ │ ├── phase1_k512.cu
│ │ │ ├── phase1_k512_topklen.cu
│ │ │ ├── phase1_k576.cu
│ │ │ └── phase1_k576_topklen.cu
│ │ ├── phase1.cuh
│ │ └── phase1.h
│ ├── smxx/
│ │ └── decode/
│ │ ├── combine/
│ │ │ ├── combine.cu
│ │ │ └── combine.h
│ │ └── get_decoding_sched_meta/
│ │ ├── get_decoding_sched_meta.cu
│ │ └── get_decoding_sched_meta.h
│ └── utils.h
├── docs/
│ ├── 20250422-new-kernel-deep-dive.md
│ └── 20250929-hopper-fp8-sparse-deep-dive.md
├── flash_mla/
│ ├── __init__.py
│ └── flash_mla_interface.py
├── setup.py
└── tests/
├── kernelkit/
│ ├── .gitignore
│ ├── __init__.py
│ ├── bench.py
│ ├── compare.py
│ ├── generate.py
│ ├── precision.py
│ └── utils.py
├── lib.py
├── quant.py
├── ref.py
├── test_flash_mla_dense_decoding.py
├── test_flash_mla_sparse_decoding.py
├── test_flash_mla_sparse_prefill.py
└── test_fmha_sm100.py
SYMBOL INDEX (484 symbols across 72 files)
FILE: benchmark/bench_flash_mla.py
function scaled_dot_product_attention (line 15) | def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal...
function run_torch_mla (line 36) | def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size,...
function run_flash_mla (line 63) | def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size,...
function run_flash_infer (line 82) | def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_siz...
function _mla_attn_kernel (line 136) | def _mla_attn_kernel(
function _mla_attn (line 222) | def _mla_attn(
function _mla_softmax_reducev_kernel (line 274) | def _mla_softmax_reducev_kernel(
function _mla_softmax_reducev (line 323) | def _mla_softmax_reducev(
function mla_decode_triton (line 346) | def mla_decode_triton(
function run_flash_mla_triton (line 381) | def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, bloc...
function compare_ab (line 410) | def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv...
function compare_a (line 450) | def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, d...
function get_args (line 493) | def get_args():
FILE: benchmark/visualize.py
function parse_args (line 7) | def parse_args():
FILE: csrc/api/api.cpp
function PYBIND11_MODULE (line 8) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/api/common.h
function is_sm90a (line 21) | struct Arch {
function int64_stride_to_int (line 44) | inline int int64_stride_to_int(int64_t orig_stride) {
function get_enum_max (line 124) | constexpr std::size_t get_enum_max(){
function std (line 133) | constexpr std::string get_dynamic_enum_name(T value){
function virtual (line 171) | constexpr virtual inline std::span<const FeatureT> get_supported_feature...
FILE: csrc/api/sparse_decode.h
type class (line 14) | enum class
type DecodeImplMeta (line 30) | struct DecodeImplMeta {
FILE: csrc/api/sparse_fwd.h
type class (line 12) | enum class
function class (line 29) | class Fwd_Sm90_Impl : public FwdImplBase {
function class (line 50) | class Fwd_Sm100_Head64_Impl : public FwdImplBase {
function class (line 68) | class Fwd_Sm100_Head128_Impl : public FwdImplBase {
function class (line 86) | class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase {
FILE: csrc/defines.h
type int32x8_t (line 13) | struct int32x8_t {
type float8 (line 17) | struct float8 {
type bf16x8 (line 21) | struct bf16x8 {
FILE: csrc/kerutils/include/kerutils/common/common.h
function namespace (line 3) | namespace kerutils {}
FILE: csrc/kerutils/include/kerutils/device/common.h
type class (line 16) | enum class
function PrefetchSize (line 25) | enum class PrefetchSize {
FILE: csrc/kerutils/include/kerutils/host/host.h
function namespace (line 15) | namespace kerutils {
FILE: csrc/kerutils/include/kerutils/supplemental/torch_tensors.h
function namespace (line 9) | namespace kerutils {
FILE: csrc/params.h
function ModelType (line 5) | enum class ModelType {
type DenseAttnDecodeParams (line 19) | struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
type SparseAttnDecodeParams (line 63) | struct SparseAttnDecodeParams {
type CombineParams (line 105) | struct CombineParams {
type GetDecodeSchedMetaParams (line 127) | struct GetDecodeSchedMetaParams {
type SparseAttnFwdParams (line 145) | struct SparseAttnFwdParams {
function SparseAttnFwdMode (line 171) | enum class SparseAttnFwdMode {
FILE: csrc/sm100/decode/head64/config.h
function namespace (line 14) | namespace sm100::decode::head64 {
FILE: csrc/sm100/decode/head64/kernel.h
function namespace (line 5) | namespace sm100::decode::head64 {
FILE: csrc/sm100/helpers.h
function namespace (line 9) | namespace sm100 {
FILE: csrc/sm100/prefill/dense/collective/fmha_common.hpp
type cutlass::fmha::collective (line 37) | namespace cutlass::fmha::collective {
function CUTE_DEVICE (line 42) | CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB cons...
function CUTE_DEVICE (line 56) | CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB,...
function CUTE_DEVICE (line 62) | CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Sta...
function CUTE_DEVICE (line 67) | CUTE_DEVICE T warp_uniform(T a) {
FILE: csrc/sm100/prefill/dense/collective/fmha_fusion.hpp
type cutlass::fmha::collective (line 37) | namespace cutlass::fmha::collective {
type NoMask (line 41) | struct NoMask {
method CUTLASS_DEVICE (line 43) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 53) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 63) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 73) | CUTLASS_DEVICE
type ResidualMask (line 83) | struct ResidualMask : NoMask {
method CUTLASS_DEVICE (line 88) | CUTLASS_DEVICE int get_masked_trip_count(
method CUTLASS_DEVICE (line 100) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 114) | CUTLASS_DEVICE
type ResidualMaskForBackward (line 135) | struct ResidualMaskForBackward : NoMask {
method CUTLASS_DEVICE (line 140) | CUTLASS_DEVICE int get_masked_trip_count(
method CUTLASS_DEVICE (line 152) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 166) | CUTLASS_DEVICE
type CausalMask (line 191) | struct CausalMask : NoMask {
method CUTLASS_DEVICE (line 198) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 218) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 234) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 244) | CUTLASS_DEVICE
type CausalForBackwardMask (line 280) | struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForB...
method CUTLASS_DEVICE (line 285) | CUTLASS_DEVICE
type VariableLength (line 316) | struct VariableLength {
method CUTE_HOST_DEVICE (line 321) | CUTE_HOST_DEVICE operator int() const {
type is_variable_length_impl (line 326) | struct is_variable_length_impl : std::false_type {}
type is_variable_length_impl<VariableLength> (line 327) | struct is_variable_length_impl<VariableLength> : std::true_type {}
function CUTE_HOST_DEVICE (line 331) | CUTE_HOST_DEVICE
function CUTE_HOST_DEVICE (line 345) | CUTE_HOST_DEVICE
function CUTE_HOST_DEVICE (line 361) | CUTE_HOST_DEVICE
type cute (line 386) | namespace cute {
type is_integral<cutlass::fmha::collective::VariableLength> (line 389) | struct is_integral<cutlass::fmha::collective::VariableLength> : true_t...
function CUTE_HOST_DEVICE (line 391) | CUTE_HOST_DEVICE
FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
type cutlass::fmha::collective (line 38) | namespace cutlass::fmha::collective {
type Sm100FmhaFwdEpilogueTmaWarpspecialized (line 48) | struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
type TensorStorage (line 64) | struct TensorStorage {
type Arguments (line 71) | struct Arguments {
type Params (line 86) | struct Params {
method CUTLASS_DEVICE (line 96) | CUTLASS_DEVICE static constexpr
method Params (line 107) | static Params to_underlying_arguments(
method CUTLASS_DEVICE (line 145) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 152) | CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& ...
method store (line 155) | CUTLASS_DEVICE auto
FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
type cutlass::fmha::collective (line 44) | namespace cutlass::fmha::collective {
type Sm100FmhaFwdMainloopTmaWarpspecialized (line 65) | struct Sm100FmhaFwdMainloopTmaWarpspecialized {
type TensorStorage (line 113) | struct TensorStorage {
type TmemAllocation (line 121) | enum class TmemAllocation : uint32_t {
type Arguments (line 187) | struct Arguments {
type Params (line 202) | struct Params {
method can_implement (line 212) | static bool can_implement(ProblemShape const& problem_shape, Argumen...
method Params (line 217) | static Params to_underlying_arguments(
method CUTLASS_DEVICE (line 236) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 242) | CUTLASS_DEVICE void
method mma (line 258) | CUTLASS_DEVICE auto
method softmax_step (line 514) | CUTLASS_DEVICE auto
method softmax (line 714) | CUTLASS_DEVICE auto
method correction_epilogue (line 778) | CUTLASS_DEVICE auto
method correction_rescale (line 868) | CUTLASS_DEVICE auto
method correction (line 954) | CUTLASS_DEVICE auto
method correction_empty (line 1142) | CUTLASS_DEVICE auto
FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp
type cutlass::fmha::collective (line 42) | namespace cutlass::fmha::collective {
type Sm100FmhaLoadTmaWarpspecialized (line 62) | struct Sm100FmhaLoadTmaWarpspecialized {
type Arguments (line 67) | struct Arguments {
type Params (line 80) | struct Params {
method Params (line 87) | static Params to_underlying_arguments(
method CUTLASS_DEVICE (line 141) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 149) | CUTLASS_DEVICE void
FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
type cutlass::fmha::collective (line 45) | namespace cutlass::fmha::collective {
type Sm100MlaFwdMainloopTmaWarpspecialized (line 65) | struct Sm100MlaFwdMainloopTmaWarpspecialized {
type TensorStorageQKVO (line 127) | struct TensorStorageQKVO {
type TensorStorageQKV (line 134) | struct TensorStorageQKV {
type TmemAllocation (line 142) | enum class TmemAllocation : uint32_t {
type Arguments (line 205) | struct Arguments {
type Params (line 220) | struct Params {
method can_implement (line 230) | static bool can_implement(ProblemShape const& problem_shape, Argumen...
method Params (line 235) | static Params to_underlying_arguments(
method CUTLASS_DEVICE (line 254) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 260) | CUTLASS_DEVICE void
method mma (line 276) | CUTLASS_DEVICE auto
method softmax_step (line 532) | CUTLASS_DEVICE auto
method softmax (line 735) | CUTLASS_DEVICE auto
method correction_epilogue (line 786) | CUTLASS_DEVICE auto
method correction_rescale (line 876) | CUTLASS_DEVICE auto
method correction (line 962) | CUTLASS_DEVICE auto
method correction_empty (line 1149) | CUTLASS_DEVICE auto
FILE: csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
type cutlass::fmha::collective (line 42) | namespace cutlass::fmha::collective {
type Sm100MlaFwdLoadTmaWarpspecialized (line 63) | struct Sm100MlaFwdLoadTmaWarpspecialized {
type Arguments (line 74) | struct Arguments {
type Params (line 87) | struct Params {
method Params (line 94) | static Params to_underlying_arguments(
method CUTLASS_DEVICE (line 149) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 157) | CUTLASS_DEVICE void
FILE: csrc/sm100/prefill/dense/common/gather_tensor.hpp
type example (line 37) | namespace example {
type NoGather (line 42) | struct NoGather
method NoGather (line 45) | NoGather(Ts...) {}
type IndexedGather (line 50) | struct IndexedGather
method CUTE_HOST_DEVICE (line 52) | CUTE_HOST_DEVICE constexpr
method CUTE_HOST_DEVICE (line 56) | CUTE_HOST_DEVICE constexpr
method print (line 61) | void
type StridedGather (line 72) | struct StridedGather
method CUTE_HOST_DEVICE (line 74) | CUTE_HOST_DEVICE constexpr
method CUTE_HOST_DEVICE (line 78) | CUTE_HOST_DEVICE constexpr
method print (line 83) | void
type CustomStride (line 95) | struct CustomStride
method CUTE_HOST_DEVICE (line 101) | CUTE_HOST_DEVICE constexpr friend
method CUTE_HOST_DEVICE (line 106) | CUTE_HOST_DEVICE constexpr friend
method print (line 111) | void
method CUTE_HOST_DEVICE (line 121) | CUTE_HOST_DEVICE constexpr friend
method CUTE_HOST_DEVICE (line 130) | CUTE_HOST_DEVICE constexpr friend
function make_custom_stride_layout (line 142) | CUTLASS_HOST_DEVICE
function make_gather_tensor (line 155) | CUTLASS_HOST_DEVICE
type cute (line 171) | namespace cute
function CUTE_HOST_DEVICE (line 175) | CUTE_HOST_DEVICE constexpr
function CUTE_HOST_DEVICE (line 195) | CUTE_HOST_DEVICE constexpr
FILE: csrc/sm100/prefill/dense/common/pipeline_mla.hpp
type cutlass (line 40) | namespace cutlass {
class PipelineTmaAsyncMla (line 49) | class PipelineTmaAsyncMla {
method CUTLASS_DEVICE (line 72) | static
method CUTLASS_DEVICE (line 90) | static
method CUTLASS_DEVICE (line 110) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 119) | CUTLASS_DEVICE
method if (line 142) | if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
method if (line 147) | if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
function CUTLASS_DEVICE (line 171) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 176) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 198) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 203) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 208) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 213) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 228) | CUTLASS_DEVICE
FILE: csrc/sm100/prefill/dense/common/pow_2.hpp
type cutlass::fmha (line 39) | namespace cutlass::fmha {
type Pow2 (line 41) | struct Pow2 {
method CUTE_HOST_DEVICE (line 52) | CUTE_HOST_DEVICE T operator *(T const& b) const {
function CUTE_HOST_DEVICE (line 77) | CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
function CUTE_HOST_DEVICE (line 81) | CUTE_HOST_DEVICE void print(Pow2 const& a) {
type cute (line 87) | namespace cute {
type is_integral<cutlass::fmha::Pow2> (line 90) | struct is_integral<cutlass::fmha::Pow2> : true_type {}
FILE: csrc/sm100/prefill/dense/common/utils.hpp
type cutlass_dtype (line 8) | struct cutlass_dtype {
type cutlass_dtype<half> (line 13) | struct cutlass_dtype<half> {
type cutlass_dtype<nv_bfloat16> (line 18) | struct cutlass_dtype<nv_bfloat16> {
type cutlass_dtype<__nv_fp8_e4m3> (line 23) | struct cutlass_dtype<__nv_fp8_e4m3> {
type cutlass_dtype<__nv_fp8_e5m2> (line 28) | struct cutlass_dtype<__nv_fp8_e5m2> {
FILE: csrc/sm100/prefill/dense/device/fmha.hpp
type cutlass::fmha::device (line 49) | namespace cutlass::fmha::device {
class FMHA (line 56) | class FMHA {
method is_initialized (line 72) | bool is_initialized(bool set = false) {
method Params (line 81) | Params const& params() const {
method Status (line 86) | static Status
method get_workspace_size (line 97) | static size_t
method dim3 (line 105) | static dim3
method maximum_active_blocks (line 111) | static int maximum_active_blocks(int /* smem_capacity */ = -1) {
method Status (line 153) | Status
method Status (line 190) | Status
method Status (line 205) | static Status
method Status (line 249) | Status
method Status (line 259) | Status
method Status (line 265) | Status
method Status (line 271) | Status
FILE: csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp
type cutlass::fmha::device (line 48) | namespace cutlass::fmha::device {
class Sm100FmhaBwd (line 62) | class Sm100FmhaBwd {
type Arguments (line 65) | struct Arguments {
type Params (line 119) | struct Params {
method to_sum_OdO_arguments (line 130) | static typename OperationSumOdO::Arguments to_sum_OdO_arguments(
method to_convert_arguments (line 153) | static typename OperationConvert::Arguments to_convert_arguments(Arg...
method to_bwd_arguments (line 172) | static typename Operation::Arguments to_bwd_arguments(
method Status (line 197) | static Status
method get_workspace_size (line 220) | static size_t
method Status (line 237) | Status
method Status (line 266) | Status
method Status (line 286) | static Status
method Status (line 319) | Status
method Status (line 329) | Status
FILE: csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp
type cutlass::fmha::kernel (line 38) | namespace cutlass::fmha::kernel {
type CausalIndividualTileScheduler (line 45) | struct CausalIndividualTileScheduler {
type Params (line 51) | struct Params {
method CUTLASS_DEVICE (line 62) | CUTLASS_DEVICE
method Params (line 66) | static Params to_underlying_arguments(
method dim3 (line 78) | static dim3 get_grid_shape(Params const& params) {
method CUTLASS_DEVICE (line 82) | CUTLASS_DEVICE
method get_block_coord (line 87) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 112) | CUTLASS_DEVICE
type CausalPersistentTileScheduler (line 125) | struct CausalPersistentTileScheduler {
type Params (line 127) | struct Params {
method Params (line 143) | static Params to_underlying_arguments(
method dim3 (line 168) | static dim3 get_grid_shape(Params const& params) {
method CUTLASS_DEVICE (line 173) | CUTLASS_DEVICE
method get_block_coord (line 178) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 189) | CUTLASS_DEVICE
FILE: csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp
type cutlass::fmha::kernel (line 39) | namespace cutlass::fmha::kernel {
type FmhaKernelBwdConvert (line 44) | struct FmhaKernelBwdConvert {
type Arguments (line 46) | struct Arguments {
method get_workspace_size (line 77) | static size_t get_workspace_size(Arguments const& args) { return 0; }
method initialize_workspace (line 78) | static cutlass::Status initialize_workspace(Arguments const&, void*,...
method can_implement (line 88) | static bool can_implement(Arguments const& args) {
method dim3 (line 92) | static dim3 get_grid_shape(Params const& params) {
method dim3 (line 97) | static dim3 get_block_shape() {
method Params (line 102) | static Params to_underlying_arguments(Arguments const& args, void* w...
method CUTLASS_DEVICE (line 107) | CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr...
method CUTLASS_DEVICE (line 141) | CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
FILE: csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp
type cutlass::fmha::kernel (line 39) | namespace cutlass::fmha::kernel {
type FmhaKernelBwdSumOdO (line 44) | struct FmhaKernelBwdSumOdO {
type Arguments (line 46) | struct Arguments {
method get_workspace_size (line 76) | static size_t get_workspace_size(Arguments const& args) { return 0; }
method initialize_workspace (line 77) | static cutlass::Status initialize_workspace(Arguments const&, void*,...
method can_implement (line 89) | static bool can_implement(Arguments const& args) {
method dim3 (line 93) | static dim3 get_grid_shape(Params const& params) {
method dim3 (line 98) | static dim3 get_block_shape() {
method Params (line 103) | static Params to_underlying_arguments(Arguments const& args, void* w...
method CUTLASS_DEVICE (line 107) | CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
FILE: csrc/sm100/prefill/dense/kernel/fmha_options.hpp
type cutlass::fmha::kernel (line 38) | namespace cutlass::fmha::kernel {
type find_option (line 41) | struct find_option
type find_option<kTag, Default> (line 44) | struct find_option<kTag, Default> {
type Tag (line 60) | enum class Tag {
type Option (line 80) | struct Option {
type find_option<kTag, Default, Option, Options...> (line 49) | struct find_option<kTag, Default, Option, Options...> :
FILE: csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp
type cutlass::fmha::kernel (line 40) | namespace cutlass::fmha::kernel {
type IndividualTileScheduler (line 44) | struct IndividualTileScheduler {
type Params (line 46) | struct Params {
method CUTLASS_DEVICE (line 52) | CUTLASS_DEVICE
method Params (line 56) | static Params to_underlying_arguments(
method dim3 (line 64) | static dim3 get_grid_shape(Params const& params) {
method CUTLASS_DEVICE (line 68) | CUTLASS_DEVICE
method get_block_coord (line 73) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 79) | CUTLASS_DEVICE
type PersistentTileScheduler (line 88) | struct PersistentTileScheduler {
type Params (line 90) | struct Params {
method Params (line 106) | static Params to_underlying_arguments(
method dim3 (line 131) | static dim3 get_grid_shape(Params const& params) {
method CUTLASS_DEVICE (line 136) | CUTLASS_DEVICE
method get_block_coord (line 141) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 152) | CUTLASS_DEVICE
FILE: csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
type cutlass::fmha::kernel (line 49) | namespace cutlass::fmha::kernel {
type Sm100FmhaBwdKernelTmaWarpSpecialized (line 62) | struct Sm100FmhaBwdKernelTmaWarpSpecialized {
type TmemAllocation (line 72) | struct TmemAllocation {
type WarpRole (line 87) | enum class WarpRole {
method CUTLASS_DEVICE (line 94) | CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
type RegisterAllocation (line 98) | struct RegisterAllocation {
type PipelineStorage (line 204) | struct PipelineStorage {
method CUTE_DEVICE (line 218) | static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stag...
type TensorStorage (line 242) | struct TensorStorage {
type SharedStorage (line 271) | struct SharedStorage {
type MainloopArguments (line 284) | struct MainloopArguments {
type MainloopParams (line 316) | struct MainloopParams {
type EpilogueArguments (line 324) | struct EpilogueArguments {
type Arguments (line 331) | struct Arguments {
type Params (line 338) | struct Params {
method can_implement (line 347) | static bool can_implement(Arguments const& args) {
method Status (line 360) | static Status initialize_workspace(Arguments const&, void*, cudaStre...
method Params (line 365) | static Params to_underlying_arguments(Arguments const& args, void*) {
method quantize (line 414) | static CUTLASS_DEVICE auto quantize(T const& input) {
method CUTLASS_DEVICE (line 432) | CUTLASS_DEVICE void load(
method CUTLASS_DEVICE (line 661) | CUTLASS_DEVICE void mma(
method CUTLASS_DEVICE (line 946) | CUTLASS_DEVICE void store(
method CUTLASS_DEVICE (line 971) | CUTLASS_DEVICE void epilogue_clear(
method CUTLASS_DEVICE (line 1015) | CUTLASS_DEVICE void epilogue(
method CUTLASS_DEVICE (line 1119) | CUTLASS_DEVICE void compute(
method CUTLASS_DEVICE (line 1392) | CUTLASS_DEVICE void reduce(
method CUTLASS_DEVICE (line 1489) | CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
method dim3 (line 1822) | static dim3 get_block_shape() {
method dim3 (line 1827) | static dim3 get_grid_shape(Params const& params) {
FILE: csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp
type cutlass::fmha::kernel (line 49) | namespace cutlass::fmha::kernel {
type Sm100FmhaBwdMlaKernelTmaWarpSpecialized (line 62) | struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
type TmemAllocation (line 70) | struct TmemAllocation {
type WarpRole (line 85) | enum class WarpRole {
method CUTLASS_DEVICE (line 95) | CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
type RegisterAllocation (line 99) | struct RegisterAllocation {
type PipelineStorage (line 205) | struct PipelineStorage {
method CUTE_DEVICE (line 219) | static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stag...
type TensorStorage (line 245) | struct TensorStorage {
type SharedStorage (line 278) | struct SharedStorage {
type MainloopArguments (line 291) | struct MainloopArguments {
type MainloopParams (line 323) | struct MainloopParams {
type EpilogueArguments (line 331) | struct EpilogueArguments {
type Arguments (line 338) | struct Arguments {
type Params (line 345) | struct Params {
method can_implement (line 354) | static bool can_implement(Arguments const& args) {
method Status (line 367) | static Status initialize_workspace(Arguments const&, void*, cudaStre...
method Params (line 372) | static Params to_underlying_arguments(Arguments const& args, void*) {
method quantize (line 421) | static CUTLASS_DEVICE auto quantize(T const& input) {
method CUTLASS_DEVICE (line 439) | CUTLASS_DEVICE void load(
method CUTLASS_DEVICE (line 667) | CUTLASS_DEVICE void mma(
method CUTLASS_DEVICE (line 951) | CUTLASS_DEVICE void store(
method CUTLASS_DEVICE (line 976) | CUTLASS_DEVICE void epilogue_clear(
method CUTLASS_DEVICE (line 1021) | CUTLASS_DEVICE void epilogue(
method CUTLASS_DEVICE (line 1125) | CUTLASS_DEVICE void compute(
method CUTLASS_DEVICE (line 1386) | CUTLASS_DEVICE void reduce(
method CUTLASS_DEVICE (line 1483) | CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
method dim3 (line 1816) | static dim3 get_block_shape() {
method dim3 (line 1821) | static dim3 get_grid_shape(Params const& params) {
FILE: csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
type cutlass::fmha::kernel (line 47) | namespace cutlass::fmha::kernel {
type Sm100FmhaCtxKernelWarpspecializedSchedule (line 52) | struct Sm100FmhaCtxKernelWarpspecializedSchedule {
type WarpRole (line 54) | enum class WarpRole {
method WarpRole (line 64) | static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
type Sm100MlaFwdCtxKernelWarpspecializedSchedule (line 91) | struct Sm100MlaFwdCtxKernelWarpspecializedSchedule {
type WarpRole (line 93) | enum class WarpRole {
method WarpRole (line 103) | static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
type Sm100FmhaFwdKernelTmaWarpspecialized (line 136) | struct Sm100FmhaFwdKernelTmaWarpspecialized {
method WarpRole (line 143) | constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
type SharedStorage (line 168) | struct SharedStorage {
type PipelineStorage (line 188) | struct PipelineStorage {
type Arguments (line 205) | struct Arguments {
type Params (line 212) | struct Params {
method get_workspace_size (line 223) | static size_t get_workspace_size(Arguments const& args) { return 0; }
method initialize_workspace (line 224) | static cutlass::Status initialize_workspace(Arguments const&, void*,...
method can_implement (line 228) | static bool can_implement(Arguments const& args) {
method dim3 (line 232) | static dim3 get_grid_shape(Params const& params) {
method dim3 (line 236) | static dim3 get_block_shape() {
method Params (line 241) | static Params to_underlying_arguments(Arguments const& args, void* w...
method apply_batch (line 250) | CUTLASS_DEVICE auto apply_batch(const Params &params, ProblemShape c...
method CUTLASS_DEVICE (line 254) | CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
FILE: csrc/sm100/prefill/sparse/common_subroutine.h
function namespace (line 6) | namespace sm100 {
FILE: csrc/sm100/prefill/sparse/fwd/head128/config.h
function namespace (line 10) | namespace sm100::fwd::head128 {
FILE: csrc/sm100/prefill/sparse/fwd/head128/phase1.h
function namespace (line 5) | namespace sm100::fwd::head128 {
FILE: csrc/sm100/prefill/sparse/fwd/head64/config.h
function namespace (line 8) | namespace sm100::fwd::head64 {
FILE: csrc/sm100/prefill/sparse/fwd/head64/phase1.h
function namespace (line 5) | namespace sm100::fwd::head64 {
FILE: csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h
function namespace (line 12) | namespace sm100::fwd_for_small_topk::head128 {
FILE: csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h
function namespace (line 5) | namespace sm100::fwd_for_small_topk::head128 {
FILE: csrc/sm90/decode/dense/config.h
function namespace (line 3) | namespace Config {
FILE: csrc/sm90/decode/dense/splitkv_mla.h
function namespace (line 5) | namespace sm90 {
FILE: csrc/sm90/decode/dense/traits.h
type SharedMemoryPlan (line 71) | struct SharedMemoryPlan {
type NamedBarriers (line 101) | enum NamedBarriers : int {
FILE: csrc/sm90/decode/sparse_fp8/components/config.h
function namespace (line 10) | namespace sm90::decode::sparse_fp8 {
FILE: csrc/sm90/decode/sparse_fp8/components/dequant.h
type fp8x8 (line 10) | struct fp8x8 {
type fp8x16 (line 15) | struct fp8x16 {
function bf16x8 (line 21) | bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale...
type class (line 36) | enum class
function L2PrefetchHint (line 43) | enum class L2PrefetchHint {
FILE: csrc/sm90/decode/sparse_fp8/components/helpers.h
function namespace (line 10) | namespace sm90::decode::sparse_fp8 {
function st_async_128b (line 79) | void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mb...
function CUTE_DEVICE (line 90) | CUTE_DEVICE
FILE: csrc/sm90/decode/sparse_fp8/config.h
function namespace (line 13) | namespace sm90::decode::sparse_fp8 {
FILE: csrc/sm90/decode/sparse_fp8/splitkv_mla.h
function namespace (line 5) | namespace sm90::decode::sparse_fp8 {
FILE: csrc/sm90/helpers.h
function namespace (line 6) | namespace sm90 {
FILE: csrc/sm90/prefill/sparse/config.h
function namespace (line 14) | namespace sm90::fwd {
FILE: csrc/sm90/prefill/sparse/fwd.h
function namespace (line 5) | namespace sm90 {
FILE: csrc/sm90/prefill/sparse/phase1.h
function namespace (line 5) | namespace sm90::fwd {
FILE: csrc/smxx/decode/combine/combine.h
function namespace (line 5) | namespace smxx::decode {
FILE: csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h
function namespace (line 5) | namespace smxx::decode {
FILE: csrc/utils.h
type RingBufferState (line 58) | struct RingBufferState {
function RingBufferState (line 75) | RingBufferState offset_by(const int offset) const {
FILE: flash_mla/flash_mla_interface.py
class FlashMLASchedMeta (line 9) | class FlashMLASchedMeta:
class Config (line 15) | class Config:
function get_mla_metadata (line 37) | def get_mla_metadata(
function flash_mla_with_kvcache (line 53) | def flash_mla_with_kvcache(
function flash_mla_sparse_fwd (line 176) | def flash_mla_sparse_fwd(
function _flash_attn_varlen_forward (line 214) | def _flash_attn_varlen_forward(
function _flash_attn_varlen_backward (line 261) | def _flash_attn_varlen_backward(
class FlashAttnVarlenFunc (line 328) | class FlashAttnVarlenFunc(torch.autograd.Function):
method forward (line 329) | def forward(
method backward (line 356) | def backward(
function flash_attn_varlen_func (line 372) | def flash_attn_varlen_func(
function flash_attn_varlen_qkvpacked_func (line 395) | def flash_attn_varlen_qkvpacked_func(
function flash_attn_varlen_kvpacked_func (line 415) | def flash_attn_varlen_kvpacked_func(
FILE: setup.py
function is_flag_set (line 16) | def is_flag_set(flag: str) -> bool:
function get_features_args (line 19) | def get_features_args():
function get_arch_flags (line 25) | def get_arch_flags():
function get_nvcc_thread_args (line 48) | def get_nvcc_thread_args():
FILE: tests/kernelkit/bench.py
class empty_suppress (line 9) | class empty_suppress:
method __enter__ (line 10) | def __enter__(self):
method __exit__ (line 13) | def __exit__(self, *_):
function profiler_range_start_marker_kernel (line 17) | def profiler_range_start_marker_kernel():
function _run_profiler_range_start_marker_kernel (line 20) | def _run_profiler_range_start_marker_kernel():
class BenchKinetoRawResult (line 24) | class BenchKinetoRawResult:
method _get_matched_kernel_name (line 33) | def _get_matched_kernel_name(self, name_substr: str, allow_no_match: b...
method get_kernel_names (line 42) | def get_kernel_names(self) -> List[str]:
method get_kernel_times (line 45) | def get_kernel_times(self, kernel_names_substr: List[str], allow_indiv...
method get_kernel_time (line 74) | def get_kernel_time(self, kernel_name_substr: str) -> float:
method get_e2e_time (line 77) | def get_e2e_time(self, start_kernel_name_substr: str, end_kenrel_name_...
function bench_kineto (line 103) | def bench_kineto(fn: Callable, num_tests: int = 30,
function bench_by_cuda_events (line 161) | def bench_by_cuda_events(kernels: List[Callable], num_warmups_each: int,...
function bench_by_cuda_events (line 164) | def bench_by_cuda_events(kernels: Callable, num_warmups_each: int, num_r...
function bench_by_cuda_events (line 166) | def bench_by_cuda_events(kernels: Union[List[Callable], Callable], num_w...
FILE: tests/kernelkit/compare.py
function check_is_bitwise_equal_comparator (line 5) | def check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tens...
function check_is_bitwise_equal (line 13) | def check_is_bitwise_equal(name: str, ans: torch.Tensor, ref: torch.Tens...
function get_cos_diff (line 19) | def get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float:
function check_is_allclose (line 31) | def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, a...
function check_is_allclose_comparator (line 94) | def check_is_allclose_comparator(name: str, ans: torch.Tensor, ref: torc...
FILE: tests/kernelkit/generate.py
function _get_new_non_contiguous_tensor_shape (line 3) | def _get_new_non_contiguous_tensor_shape(shape):
function gen_non_contiguous_randn_tensor (line 10) | def gen_non_contiguous_randn_tensor(shape, *args, **kwargs):
function gen_non_contiguous_tensor (line 16) | def gen_non_contiguous_tensor(shape, *args, **kwargs):
function non_contiguousify (line 22) | def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor:
FILE: tests/kernelkit/precision.py
class LowPrecisionMode (line 5) | class LowPrecisionMode:
method __init__ (line 6) | def __init__(self, enabled: bool = True):
method __enter__ (line 9) | def __enter__(self):
method __exit__ (line 13) | def __exit__(self, exc_type, exc_value, traceback):
function is_low_precision_mode (line 17) | def is_low_precision_mode() -> bool:
function optional_cast_to_bf16_and_cast_back (line 23) | def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.T...
FILE: tests/kernelkit/utils.py
function cdiv (line 18) | def cdiv(a: int, b: int) -> int:
function is_using_profiling_tools (line 22) | def is_using_profiling_tools() -> bool:
function set_random_seed (line 33) | def set_random_seed(seed: int):
class Counter (line 44) | class Counter:
method __init__ (line 45) | def __init__(self):
method next (line 48) | def next(self) -> int:
FILE: tests/lib.py
class TestTarget (line 13) | class TestTarget(enum.Enum):
class ExtraTestParamForDecode (line 18) | class ExtraTestParamForDecode:
class TestParam (line 29) | class TestParam:
class RawTestParamForDecode (line 46) | class RawTestParamForDecode:
method to_test_param (line 74) | def to_test_param(self) -> TestParam:
class Testcase (line 90) | class Testcase:
function _randperm_batch (line 100) | def _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size...
function generate_testcase (line 121) | def generate_testcase(t: TestParam) -> Testcase:
class KVScope (line 168) | class KVScope:
method quant_and_dequant_ (line 178) | def quant_and_dequant_(self):
method get_kvcache_for_flash_mla (line 195) | def get_kvcache_for_flash_mla(self) -> torch.Tensor:
method apply_perm (line 202) | def apply_perm(self, perm: torch.Tensor) -> "KVScope":
class TestcaseForDecode (line 219) | class TestcaseForDecode:
function generate_testcase_for_decode (line 227) | def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
function run_flash_mla_sparse_fwd (line 310) | def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bo...
function run_flash_mla_decode (line 319) | def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_schedu...
class FlopsAndMemVolStatistics (line 338) | class FlopsAndMemVolStatistics:
function count_flop_and_mem_vol (line 345) | def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolS...
class FlopsAndMemVolStatisticsForDecode (line 360) | class FlopsAndMemVolStatisticsForDecode:
function count_flop_and_mem_vol_for_decode (line 367) | def count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode...
function is_no_cooldown (line 404) | def is_no_cooldown() -> bool:
FILE: tests/quant.py
class FP8KVCacheLayout (line 6) | class FP8KVCacheLayout(enum.Enum):
method get_meta (line 10) | def get_meta(self) -> Tuple[int, int, int, int, int]:
function _cast_scale_inv_to_ue8m0 (line 17) | def _cast_scale_inv_to_ue8m0(scales_inv: torch.Tensor, out_dtype = torch...
function quantize_k_cache (line 20) | def quantize_k_cache(
function dequantize_k_cache (line 81) | def dequantize_k_cache(
function abs_indices2indices_in_kvcache (line 126) | def abs_indices2indices_in_kvcache(
FILE: tests/ref.py
function _merge_two_lse (line 7) | def _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q...
function ref_sparse_attn_fwd (line 19) | def ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor...
function ref_sparse_attn_decode (line 55) | def ref_sparse_attn_decode(
FILE: tests/test_flash_mla_dense_decoding.py
class TestParam (line 13) | class TestParam:
function generate_test_data (line 29) | def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor...
function reference_torch (line 73) | def reference_torch(
function test_flash_mla (line 145) | def test_flash_mla(t: TestParam):
function main (line 195) | def main(torch_dtype):
FILE: tests/test_flash_mla_sparse_decoding.py
function gen_testcase (line 23) | def gen_testcase() -> List[RawTestParam]:
class Result (line 130) | class Result:
function test_flash_mla (line 142) | def test_flash_mla(p: TestParam) -> Result:
function main (line 236) | def main():
FILE: tests/test_flash_mla_sparse_prefill.py
function run_test (line 14) | def run_test(p: TestParam) -> bool:
FILE: tests/test_fmha_sm100.py
function get_window_size (line 10) | def get_window_size(causal, window):
function get_attn_bias (line 18) | def get_attn_bias(s_q, s_k, causal, window):
function sdpa (line 31) | def sdpa(query, key, value, attn_bias, softmax_scale=None):
function sdpa_checkpoint (line 46) | def sdpa_checkpoint(*args, **kwargs):
function test_flash_attention (line 50) | def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, cau...
Condensed preview — 129 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,154K chars).
[
{
"path": ".gitignore",
"chars": 110,
"preview": "build\n*.so\n*.egg-info/\n__pycache__/\ndist/\n*perf.csv\n*.png\n/.vscode\ncompile_commands.json\n.cache\n/dev\n/.clangd\n"
},
{
"path": ".gitmodules",
"chars": 93,
"preview": "[submodule \"csrc/cutlass\"]\n\tpath = csrc/cutlass\n\turl = https://github.com/NVIDIA/cutlass.git\n"
},
{
"path": "LICENSE",
"chars": 1065,
"preview": "MIT License\n\nCopyright (c) 2025 DeepSeek\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
},
{
"path": "README.md",
"chars": 10646,
"preview": "# FlashMLA\n\n## Introduction\n\nFlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](h"
},
{
"path": "benchmark/bench_flash_mla.py",
"chars": 19156,
"preview": "# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c"
},
{
"path": "benchmark/visualize.py",
"chars": 745,
"preview": "import argparse\n\nimport matplotlib.pyplot as plt\nimport pandas as pd\n\n\ndef parse_args():\n parser = argparse.ArgumentP"
},
{
"path": "csrc/api/api.cpp",
"chars": 507,
"preview": "#include <pybind11/pybind11.h>\n\n#include \"sparse_fwd.h\"\n#include \"sparse_decode.h\"\n#include \"dense_decode.h\"\n#include \"d"
},
{
"path": "csrc/api/common.h",
"chars": 8741,
"preview": "#pragma once\n\n#include <span>\n\n#include <torch/extension.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGu"
},
{
"path": "csrc/api/dense_decode.h",
"chars": 9821,
"preview": "#pragma once\n\n#include <cutlass/half.h>\n#include <cutlass/fast_math.h>\n\n#include \"common.h\"\n#include \"params.h\"\n\n#includ"
},
{
"path": "csrc/api/dense_fwd.h",
"chars": 78,
"preview": "#pragma once\n\n#include \"common.h\"\n\n#include \"sm100/prefill/dense/interface.h\"\n"
},
{
"path": "csrc/api/sparse_decode.h",
"chars": 18293,
"preview": "#pragma once\n\n#include \"common.h\"\n\n#include \"params.h\"\n\n#include \"sm90/decode/sparse_fp8/splitkv_mla.h\"\n#include \"sm100/"
},
{
"path": "csrc/api/sparse_fwd.h",
"chars": 7749,
"preview": "#pragma once\n\n#include \"common.h\"\n\n#include \"params.h\"\n\n#include \"sm90/prefill/sparse/phase1.h\"\n#include \"sm100/prefill/"
},
{
"path": "csrc/defines.h",
"chars": 564,
"preview": "#pragma once\n\n#include <cutlass/bfloat16.h>\n#include <cutlass/arch/barrier.h>\n\nusing bf16 = cutlass::bfloat16_t;\nusing f"
},
{
"path": "csrc/kerutils/include/kerutils/common/common.h",
"chars": 143,
"preview": "#pragma once\n\nnamespace kerutils {}\n\n#define KU_PRINTLN(fmt, ...) { cute::print(fmt, ##__VA_ARGS__); print(\"\\n\"); }\n\nnam"
},
{
"path": "csrc/kerutils/include/kerutils/device/common.h",
"chars": 1601,
"preview": "/*\nCommon data types and macros that are used across the kerutils library.\n*/\n#pragma once\n\n#include <cuda_bf16.h>\n#incl"
},
{
"path": "csrc/kerutils/include/kerutils/device/device.cuh",
"chars": 320,
"preview": "#pragma once\n\n#include \"kerutils/common/common.h\"\n\n#include \"common.h\"\n#include \"sm80/intrinsics.cuh\"\n#include \"sm80/hel"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm100/gemm.cuh",
"chars": 26103,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include <kerutils/device/common.h>\n\nnamespace cute {\n\n// Extensions to CuTe\n/"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm100/helpers.cuh",
"chars": 4394,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// Perform SS UTCMM"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh",
"chars": 19645,
"preview": "#pragma once\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// tma gather4 (https://docs.nvidia.com/cuda/pa"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh",
"chars": 11349,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include <kerutils/device/common.h>\n\nnamespace cute {\n\n// Extensions to CuTe\n/"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm80/helpers.cuh",
"chars": 1216,
"preview": "#pragma once\n\n#include \"kerutils/device/common.h\"\n#include \"kerutils/device/sm80/intrinsics.cuh\"\n\nnamespace kerutils {\n\n"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh",
"chars": 7375,
"preview": "#pragma once\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// cp.async.cg (cache global) with prefetch and"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm90/helpers.cuh",
"chars": 4574,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\ntemplate<\n typen"
},
{
"path": "csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh",
"chars": 4701,
"preview": "#pragma once\n\n#include \"kerutils/device/common.h\"\n\nnamespace kerutils {\n\n// st.async (https://docs.nvidia.com/cuda/paral"
},
{
"path": "csrc/kerutils/include/kerutils/host/host.h",
"chars": 6558,
"preview": "#pragma once\n\n#include <exception>\n#include <string>\n#include <sstream>\n#include <vector>\n\n#include <cuda_runtime_api.h>"
},
{
"path": "csrc/kerutils/include/kerutils/kerutils.cuh",
"chars": 66,
"preview": "#pragma once\n\n#include \"host/host.h\"\n#include \"device/device.cuh\"\n"
},
{
"path": "csrc/kerutils/include/kerutils/supplemental/torch_tensors.h",
"chars": 3283,
"preview": "#pragma once\n\n#include <functional>\n\n#include <torch/python.h>\n\n#include \"kerutils/common/common.h\"\n\nnamespace kerutils "
},
{
"path": "csrc/params.h",
"chars": 6266,
"preview": "#pragma once\n\n#include \"cutlass/bfloat16.h\"\n\nenum class ModelType {\n V32,\n MODEL1\n};\n\nstruct __align__(4*8) Decodi"
},
{
"path": "csrc/sm100/decode/head128/README.md",
"chars": 185,
"preview": "Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_deco"
},
{
"path": "csrc/sm100/decode/head64/config.h",
"chars": 7716,
"preview": "#pragma once\n\n#include \"kernel.h\"\n\n#include <cuda_fp8.h>\n#include <cutlass/barrier.h>\n#include <cute/tensor.hpp>\n\n#inclu"
},
{
"path": "csrc/sm100/decode/head64/instantiations/model1.cu",
"chars": 176,
"preview": "#include \"../kernel.cuh\"\n\nnamespace sm100::decode::head64 {\n\ntemplate\nvoid run_flash_splitkv_mla_fp8_sparse_kernel<Model"
},
{
"path": "csrc/sm100/decode/head64/instantiations/v32.cu",
"chars": 173,
"preview": "#include \"../kernel.cuh\"\n\nnamespace sm100::decode::head64 {\n\ntemplate\nvoid run_flash_splitkv_mla_fp8_sparse_kernel<Model"
},
{
"path": "csrc/sm100/decode/head64/kernel.cuh",
"chars": 51887,
"preview": "#include \"kernel.h\"\n\n#include <math_constants.h>\n#include <cutlass/barrier.h>\n#include <cutlass/arch/barrier.h>\n#include"
},
{
"path": "csrc/sm100/decode/head64/kernel.h",
"chars": 189,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::decode::head64 {\n\ntemplate<ModelType MODEL_TYPE>\nvoid run_flash_spli"
},
{
"path": "csrc/sm100/helpers.h",
"chars": 717,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <cuda_bf16.h>\n#include <cuda_fp8.h>\n\n#include \"defines.h\"\n\nnamespace s"
},
{
"path": "csrc/sm100/prefill/dense/collective/fmha_common.hpp",
"chars": 5110,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/collective/fmha_fusion.hpp",
"chars": 13050,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp",
"chars": 8541,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp",
"chars": 45626,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp",
"chars": 11571,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp",
"chars": 46621,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp",
"chars": 12871,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/common/gather_tensor.hpp",
"chars": 7080,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/common/helper.h",
"chars": 3642,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/common/mask.cuh",
"chars": 132,
"preview": "#pragma once\n\nenum class MaskMode {\n kNone = 0U, // No mask\n kCausal = 1U, // Causal mask\n kCustom = 2U, // Cust"
},
{
"path": "csrc/sm100/prefill/dense/common/pipeline_mla.hpp",
"chars": 10204,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/common/pow_2.hpp",
"chars": 3380,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/common/utils.hpp",
"chars": 593,
"preview": "#pragma once\n\n#include <torch/extension.h>\n#include \"cutlass/numeric_types.h\"\n#include \"helper.h\"\n\ntemplate <typename T>"
},
{
"path": "csrc/sm100/prefill/dense/device/fmha.hpp",
"chars": 9837,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp",
"chars": 12737,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
"chars": 3832,
"preview": "#include \"interface.h\"\n\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n#include <cuda_bf16.h>\n#include"
},
{
"path": "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh",
"chars": 9305,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
"chars": 3699,
"preview": "#include \"interface.h\"\n\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n#include <cuda_bf16.h>\n\n#includ"
},
{
"path": "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh",
"chars": 13832,
"preview": "#pragma once\n\n#include \"collective/fmha_fusion.hpp\"\n#include \"collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp"
},
{
"path": "csrc/sm100/prefill/dense/interface.h",
"chars": 876,
"preview": "#pragma once\n\n#include <ATen/Tensor.h>\n\nvoid FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tenso"
},
{
"path": "csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp",
"chars": 6960,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp",
"chars": 6706,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp",
"chars": 6967,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/kernel/fmha_options.hpp",
"chars": 2852,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp",
"chars": 5536,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp",
"chars": 77689,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp",
"chars": 78104,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp",
"chars": 26129,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/sm100/prefill/sparse/common_subroutine.h",
"chars": 6278,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\nnamespace sm100 {\n\n/*\nLoad K/V indices from g"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head128/config.h",
"chars": 4803,
"preview": "#pragma once\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\n#include \"params."
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
"chars": 162,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head128 {\n\ntemplate void run_fwd_phase1_kernel<51"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
"chars": 162,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head128 {\n\ntemplate void run_fwd_phase1_kernel<57"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh",
"chars": 30554,
"preview": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cluster_launc"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head128/phase1.h",
"chars": 153,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::fwd::head128 {\n\ntemplate<int D_QK>\nvoid run_fwd_phase1_kernel(const "
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head64/config.h",
"chars": 4871,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <kerutils/kerutils.cuh>\n\n#include \"defines.h\"\n\nnamespace sm100::fwd::h"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
"chars": 161,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head64 {\n\ntemplate void run_fwd_phase1_kernel<512"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
"chars": 161,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd::head64 {\n\ntemplate void run_fwd_phase1_kernel<576"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh",
"chars": 28868,
"preview": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/arch/reg_reco"
},
{
"path": "csrc/sm100/prefill/sparse/fwd/head64/phase1.h",
"chars": 152,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::fwd::head64 {\n\ntemplate<int D_QK>\nvoid run_fwd_phase1_kernel(const S"
},
{
"path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h",
"chars": 5015,
"preview": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cutlass/float8.h>\n#include <cute/tensor.hpp>\n#in"
},
{
"path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
"chars": 233,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\ntemplate void run_fwd_f"
},
{
"path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",
"chars": 220,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\ntemplate void run_fwd_f"
},
{
"path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh",
"chars": 56963,
"preview": "#pragma once\n#include \"phase1.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cluster_launc"
},
{
"path": "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h",
"chars": 215,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm100::fwd_for_small_topk::head128 {\n\ntemplate<SparseAttnFwdMode FWD_MODE, "
},
{
"path": "csrc/sm90/decode/dense/config.h",
"chars": 199,
"preview": "#pragma once\n\nnamespace Config {\n\nstatic constexpr int BLOCK_SIZE_M = 64;\nstatic constexpr int PAGE_BLOCK_SIZE = 64;\n\nst"
},
{
"path": "csrc/sm90/decode/dense/instantiations/bf16.cu",
"chars": 176,
"preview": "#include \"../splitkv_mla.cuh\"\n#include \"../splitkv_mla.h\"\n\nnamespace sm90 {\n\ntemplate void run_flash_splitkv_mla_kernel<"
},
{
"path": "csrc/sm90/decode/dense/instantiations/fp16.cu",
"chars": 210,
"preview": "#include \"../splitkv_mla.cuh\"\n#include \"../splitkv_mla.h\"\n\nnamespace sm90 {\n\n#ifndef FLASH_MLA_DISABLE_FP16\ntemplate voi"
},
{
"path": "csrc/sm90/decode/dense/splitkv_mla.cuh",
"chars": 58333,
"preview": "#include <cutlass/cutlass.h>\n\n#include \"utils.h\"\n\n#include \"params.h\"\n#include \"config.h\"\n#include \"traits.h\"\n\nusing nam"
},
{
"path": "csrc/sm90/decode/dense/splitkv_mla.h",
"chars": 148,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm90 {\n\ntemplate<typename InputT>\nvoid run_flash_splitkv_mla_kernel(DenseAt"
},
{
"path": "csrc/sm90/decode/dense/traits.h",
"chars": 3644,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <cutlass/cutlass.h>\n#include <cutlass/numeric_types.h>\n#include <cutla"
},
{
"path": "csrc/sm90/decode/sparse_fp8/components/config.h",
"chars": 652,
"preview": "#pragma once\n\n#include <cutlass/numeric_types.h>\n#include <cutlass/arch/barrier.h>\n#include <cute/tensor.hpp>\n#include \""
},
{
"path": "csrc/sm90/decode/sparse_fp8/components/dequant.h",
"chars": 3559,
"preview": "#pragma once\n\n#include <cuda_fp8.h>\n#include <cuda_bf16.h>\n\n#include \"defines.h\"\n\nnamespace sm90::decode::sparse_fp8 {\n\n"
},
{
"path": "csrc/sm90/decode/sparse_fp8/components/helpers.h",
"chars": 4300,
"preview": "#pragma once\n\n#include <cooperative_groups.h>\n#include <cute/tensor.hpp>\n\n#include \"config.h\"\n\nusing namespace cute;\n\nna"
},
{
"path": "csrc/sm90/decode/sparse_fp8/config.h",
"chars": 9285,
"preview": "#pragma once\n\n#include <cutlass/numeric_types.h>\n#include <cutlass/arch/barrier.h>\n#include <cute/tensor.hpp>\n#include <"
},
{
"path": "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
"chars": 189,
"preview": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kern"
},
{
"path": "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
"chars": 189,
"preview": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kern"
},
{
"path": "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
"chars": 186,
"preview": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kern"
},
{
"path": "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"chars": 185,
"preview": "#include \"../splitkv_mla.cuh\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate void run_flash_splitkv_mla_fp8_sparse_kern"
},
{
"path": "csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh",
"chars": 39360,
"preview": "#pragma once\n\n#include \"splitkv_mla.h\"\n\n#include <cuda_fp8.h>\n#include <math_constants.h>\n#include <cutlass/barrier.h>\n#"
},
{
"path": "csrc/sm90/decode/sparse_fp8/splitkv_mla.h",
"chars": 207,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm90::decode::sparse_fp8 {\n\ntemplate<ModelType MODEL_TYPE, int NUM_HEADS>\nv"
},
{
"path": "csrc/sm90/helpers.h",
"chars": 6659,
"preview": "#pragma once\n\n#include <cute/tensor.hpp>\n#include <cutlass/arch/barrier.h>\n\nnamespace sm90 {\n\n__forceinline__ __device__"
},
{
"path": "csrc/sm90/prefill/sparse/config.h",
"chars": 4318,
"preview": "#pragma once\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cluster_launch.hpp>\n#include <coo"
},
{
"path": "csrc/sm90/prefill/sparse/fwd.cu",
"chars": 851,
"preview": "#include \"fwd.h\"\n\n#include <stdexcept>\n\n#include \"phase1.h\"\n\nnamespace sm90 {\n\nvoid run_fwd_kernel(const SparseAttnFwdPa"
},
{
"path": "csrc/sm90/prefill/sparse/fwd.h",
"chars": 112,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace sm90 {\n\nvoid run_fwd_kernel(const SparseAttnFwdParams& params);\n\n}\n"
},
{
"path": "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
"chars": 327,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\n// NOTE (intlsy): We instantiate run_fwd_phase1_"
},
{
"path": "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
"chars": 326,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\n// NOTE (intlsy): We instantiate run_fwd_phase1_"
},
{
"path": "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
"chars": 159,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\ntemplate void run_fwd_phase1_kernel<576, false>("
},
{
"path": "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
"chars": 158,
"preview": "#include \"../phase1.h\"\n#include \"../phase1.cuh\"\n\nnamespace sm90::fwd {\n\ntemplate void run_fwd_phase1_kernel<576, true>(c"
},
{
"path": "csrc/sm90/prefill/sparse/phase1.cuh",
"chars": 27402,
"preview": "#pragma once\n\n#include \"config.h\"\n\n#include \"utils.h\"\n#include \"../../helpers.h\"\n\nnamespace sm90::fwd {\n\nusing namespace"
},
{
"path": "csrc/sm90/prefill/sparse/phase1.h",
"chars": 175,
"preview": "#pragma once\n\n#include \"../../../params.h\"\n\nnamespace sm90::fwd {\n\ntemplate<int D_QK, bool HAVE_TOPK_LENGTH>\nvoid run_fw"
},
{
"path": "csrc/smxx/decode/combine/combine.cu",
"chars": 9638,
"preview": "#include \"combine.h\"\n\n#include <math_constants.h>\n#include <cute/tensor.hpp>\n#include <cutlass/cutlass.h>\n#include <cutl"
},
{
"path": "csrc/smxx/decode/combine/combine.h",
"chars": 150,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace smxx::decode {\n\ntemplate<typename ElementT>\nvoid run_flash_mla_combine_kern"
},
{
"path": "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
"chars": 5749,
"preview": "#include \"get_decoding_sched_meta.h\"\n\n#include <cuda_runtime_api.h>\n#include <cutlass/fast_math.h>\n#include <kerutils/ke"
},
{
"path": "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h",
"chars": 139,
"preview": "#pragma once\n\n#include \"params.h\"\n\nnamespace smxx::decode {\n\nvoid run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaP"
},
{
"path": "csrc/utils.h",
"chars": 3432,
"preview": "#pragma once\n\n#include <cstdint>\n\n#define CHECK_CUDA(call) "
},
{
"path": "docs/20250422-new-kernel-deep-dive.md",
"chars": 8349,
"preview": "# A Deep-Dive Into the New Flash MLA Kernel\n\nIn the [previous version](https://github.com/deepseek-ai/FlashMLA/tree/b31b"
},
{
"path": "docs/20250929-hopper-fp8-sparse-deep-dive.md",
"chars": 7257,
"preview": "# A Deep Dive Into The Flash MLA FP8 Decoding Kernel on Hopper\n\nWith the release of DeepSeek-V3.2, we have doubled the c"
},
{
"path": "flash_mla/__init__.py",
"chars": 452,
"preview": "__version__ = \"1.0.0\"\n\nfrom flash_mla.flash_mla_interface import (\n get_mla_metadata,\n flash_mla_with_kvcache,\n "
},
{
"path": "flash_mla/flash_mla_interface.py",
"chars": 19102,
"preview": "from typing import Optional, Tuple\nimport dataclasses\n\nimport torch\n\nimport flash_mla.cuda as flash_mla_cuda\n\n@dataclass"
},
{
"path": "setup.py",
"chars": 6012,
"preview": "import os\nfrom pathlib import Path\nfrom datetime import datetime\nimport subprocess\n\nfrom setuptools import setup, find_p"
},
{
"path": "tests/kernelkit/.gitignore",
"chars": 74,
"preview": "build\n*.so\n*.egg-info/\n__pycache__/\ndist/\n/.vscode\n.cache\n/temp\n/profiles\n"
},
{
"path": "tests/kernelkit/__init__.py",
"chars": 590,
"preview": "from . import bench\nfrom . import compare\nfrom . import generate\nfrom . import precision\nfrom . import utils\n\nfrom .benc"
},
{
"path": "tests/kernelkit/bench.py",
"chars": 9366,
"preview": "from typing import Tuple, List, Callable, Union, Dict, overload\nimport dataclasses\n\nimport torch\nimport triton\n\nfrom .ut"
},
{
"path": "tests/kernelkit/compare.py",
"chars": 4421,
"preview": "from typing import List\n\nimport torch\n\ndef check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tensor, resul"
},
{
"path": "tests/kernelkit/generate.py",
"chars": 1053,
"preview": "import torch\n\ndef _get_new_non_contiguous_tensor_shape(shape):\n \"\"\"\n Get the expanded shape for a non-contiguous t"
},
{
"path": "tests/kernelkit/precision.py",
"chars": 997,
"preview": "import torch\n\n_is_low_precision_mode_stack = []\n\nclass LowPrecisionMode:\n def __init__(self, enabled: bool = True):\n "
},
{
"path": "tests/kernelkit/utils.py",
"chars": 1384,
"preview": "import os\nimport functools\n\ncolors = {\n 'RED_FG': '\\033[31m',\n 'GREEN_FG': '\\033[32m',\n 'CYAN_FG': '\\033[36m',\n"
},
{
"path": "tests/lib.py",
"chars": 16677,
"preview": "import dataclasses\nimport os\nimport enum\nfrom typing import List, Optional\nimport random\n\nimport torch\nimport kernelkit "
},
{
"path": "tests/quant.py",
"chars": 8162,
"preview": "import enum\nfrom typing import Tuple\n\nimport torch\n\nclass FP8KVCacheLayout(enum.Enum):\n V32_FP8Sparse = 1\n MODEL1_"
},
{
"path": "tests/ref.py",
"chars": 4556,
"preview": "from typing import Optional, Tuple\n\nimport torch\n\nfrom lib import TestParam, Testcase, TestcaseForDecode, KVScope\n\ndef _"
},
{
"path": "tests/test_flash_mla_dense_decoding.py",
"chars": 9054,
"preview": "import argparse\nimport math\nimport random\nimport dataclasses\nfrom typing import Tuple\n\nimport torch\n\nimport kernelkit as"
},
{
"path": "tests/test_flash_mla_sparse_decoding.py",
"chars": 14317,
"preview": "import time\nimport dataclasses\nfrom typing import Tuple, List, Dict, Optional\nimport copy\n\nimport rich.console\nimport ri"
},
{
"path": "tests/test_flash_mla_sparse_prefill.py",
"chars": 5682,
"preview": "import time\nimport sys\n\nimport torch\nimport kernelkit as kk\n\nfrom lib import TestParam\nimport lib\nimport ref\n\n_counter ="
},
{
"path": "tests/test_fmha_sm100.py",
"chars": 7553,
"preview": "import random\n\nimport torch\nfrom torch.utils.checkpoint import checkpoint\nimport triton\n\nfrom flash_mla import flash_att"
}
]
About this extraction
This page contains the full source code of the deepseek-ai/FlashMLA GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 129 files (1.1 MB), approximately 313.1k tokens, and a symbol index with 484 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.