gitextract_clsc5nbn/ ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── benchmark/ │ ├── bench_flash_mla.py │ └── visualize.py ├── csrc/ │ ├── api/ │ │ ├── api.cpp │ │ ├── common.h │ │ ├── dense_decode.h │ │ ├── dense_fwd.h │ │ ├── sparse_decode.h │ │ └── sparse_fwd.h │ ├── defines.h │ ├── kerutils/ │ │ └── include/ │ │ └── kerutils/ │ │ ├── common/ │ │ │ └── common.h │ │ ├── device/ │ │ │ ├── common.h │ │ │ ├── device.cuh │ │ │ ├── sm100/ │ │ │ │ ├── gemm.cuh │ │ │ │ ├── helpers.cuh │ │ │ │ ├── intrinsics.cuh │ │ │ │ └── tma_cta_group2_nosplit.cuh │ │ │ ├── sm80/ │ │ │ │ ├── helpers.cuh │ │ │ │ └── intrinsics.cuh │ │ │ └── sm90/ │ │ │ ├── helpers.cuh │ │ │ └── intrinsics.cuh │ │ ├── host/ │ │ │ └── host.h │ │ ├── kerutils.cuh │ │ └── supplemental/ │ │ └── torch_tensors.h │ ├── params.h │ ├── sm100/ │ │ ├── decode/ │ │ │ ├── head128/ │ │ │ │ └── README.md │ │ │ └── head64/ │ │ │ ├── config.h │ │ │ ├── instantiations/ │ │ │ │ ├── model1.cu │ │ │ │ └── v32.cu │ │ │ ├── kernel.cuh │ │ │ └── kernel.h │ │ ├── helpers.h │ │ └── prefill/ │ │ ├── dense/ │ │ │ ├── collective/ │ │ │ │ ├── fmha_common.hpp │ │ │ │ ├── fmha_fusion.hpp │ │ │ │ ├── sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp │ │ │ │ ├── sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp │ │ │ │ ├── sm100_fmha_load_tma_warpspecialized.hpp │ │ │ │ ├── sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp │ │ │ │ └── sm100_fmha_mla_load_tma_warpspecialized.hpp │ │ │ ├── common/ │ │ │ │ ├── gather_tensor.hpp │ │ │ │ ├── helper.h │ │ │ │ ├── mask.cuh │ │ │ │ ├── pipeline_mla.hpp │ │ │ │ ├── pow_2.hpp │ │ │ │ └── utils.hpp │ │ │ ├── device/ │ │ │ │ ├── fmha.hpp │ │ │ │ └── fmha_device_bwd.hpp │ │ │ ├── fmha_cutlass_bwd_sm100.cu │ │ │ ├── fmha_cutlass_bwd_sm100.cuh │ │ │ ├── fmha_cutlass_fwd_sm100.cu │ │ │ ├── fmha_cutlass_fwd_sm100.cuh │ │ │ ├── interface.h │ │ │ └── kernel/ │ │ │ ├── fmha_causal_tile_scheduler.hpp │ │ │ ├── fmha_kernel_bwd_convert.hpp │ │ │ ├── fmha_kernel_bwd_sum_OdO.hpp │ │ │ ├── fmha_options.hpp │ │ │ ├── fmha_tile_scheduler.hpp │ │ │ ├── sm100_fmha_bwd_kernel_tma_warpspecialized.hpp │ │ │ ├── sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp │ │ │ └── sm100_fmha_fwd_kernel_tma_warpspecialized.hpp │ │ └── sparse/ │ │ ├── common_subroutine.h │ │ ├── fwd/ │ │ │ ├── head128/ │ │ │ │ ├── config.h │ │ │ │ ├── instantiations/ │ │ │ │ │ ├── phase1_k512.cu │ │ │ │ │ └── phase1_k576.cu │ │ │ │ ├── phase1.cuh │ │ │ │ └── phase1.h │ │ │ └── head64/ │ │ │ ├── config.h │ │ │ ├── instantiations/ │ │ │ │ ├── phase1_k512.cu │ │ │ │ └── phase1_k576.cu │ │ │ ├── phase1.cuh │ │ │ └── phase1.h │ │ └── fwd_for_small_topk/ │ │ └── head128/ │ │ ├── config.h │ │ ├── instantiations/ │ │ │ ├── phase1_decode_k512.cu │ │ │ └── phase1_prefill_k512.cu │ │ ├── phase1.cuh │ │ └── phase1.h │ ├── sm90/ │ │ ├── decode/ │ │ │ ├── dense/ │ │ │ │ ├── config.h │ │ │ │ ├── instantiations/ │ │ │ │ │ ├── bf16.cu │ │ │ │ │ └── fp16.cu │ │ │ │ ├── splitkv_mla.cuh │ │ │ │ ├── splitkv_mla.h │ │ │ │ └── traits.h │ │ │ └── sparse_fp8/ │ │ │ ├── components/ │ │ │ │ ├── config.h │ │ │ │ ├── dequant.h │ │ │ │ └── helpers.h │ │ │ ├── config.h │ │ │ ├── instantiations/ │ │ │ │ ├── model1_persistent_h128.cu │ │ │ │ ├── model1_persistent_h64.cu │ │ │ │ ├── v32_persistent_h128.cu │ │ │ │ └── v32_persistent_h64.cu │ │ │ ├── splitkv_mla.cuh │ │ │ └── splitkv_mla.h │ │ ├── helpers.h │ │ └── prefill/ │ │ └── sparse/ │ │ ├── config.h │ │ ├── fwd.cu │ │ ├── fwd.h │ │ ├── instantiations/ │ │ │ ├── phase1_k512.cu │ │ │ ├── phase1_k512_topklen.cu │ │ │ ├── phase1_k576.cu │ │ │ └── phase1_k576_topklen.cu │ │ ├── phase1.cuh │ │ └── phase1.h │ ├── smxx/ │ │ └── decode/ │ │ ├── combine/ │ │ │ ├── combine.cu │ │ │ └── combine.h │ │ └── get_decoding_sched_meta/ │ │ ├── get_decoding_sched_meta.cu │ │ └── get_decoding_sched_meta.h │ └── utils.h ├── docs/ │ ├── 20250422-new-kernel-deep-dive.md │ └── 20250929-hopper-fp8-sparse-deep-dive.md ├── flash_mla/ │ ├── __init__.py │ └── flash_mla_interface.py ├── setup.py └── tests/ ├── kernelkit/ │ ├── .gitignore │ ├── __init__.py │ ├── bench.py │ ├── compare.py │ ├── generate.py │ ├── precision.py │ └── utils.py ├── lib.py ├── quant.py ├── ref.py ├── test_flash_mla_dense_decoding.py ├── test_flash_mla_sparse_decoding.py ├── test_flash_mla_sparse_prefill.py └── test_fmha_sm100.py