Showing preview only (5,503K chars total). Download the full file or copy to clipboard to get everything.
Repository: Dao-AILab/flash-attention
Branch: main
Commit: dd6f4a212a4f
Files: 994
Total size: 5.1 MB
Directory structure:
gitextract_wgwkfssl/
├── .github/
│ └── workflows/
│ ├── README.md
│ ├── _build.yml
│ ├── build.yml
│ ├── pre-commit.yaml
│ ├── publish-fa4.yml
│ └── publish.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── AI/
│ ├── DEBUG_2CTA.md
│ ├── RACECHECK_TMA_HAZARD.md
│ ├── SM90_BLOCK_SIZE_TUNING.md
│ ├── SM90_R2P_MASKING_SASS.md
│ ├── VARLEN_PREPROCESS_TILE_BUG.md
│ ├── racecheck_repro_1d_bulk.py
│ └── racecheck_repro_1d_tensor.py
├── AUTHORS
├── CLAUDE.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── benchmarks/
│ ├── bench_sm90.py
│ ├── benchmark_alibi.py
│ ├── benchmark_attn.py
│ ├── benchmark_causal.py
│ ├── benchmark_flash_attention.py
│ └── benchmark_gemm.py
├── csrc/
│ ├── flash_attn/
│ │ ├── flash_api.cpp
│ │ └── src/
│ │ ├── alibi.h
│ │ ├── block_info.h
│ │ ├── dropout.h
│ │ ├── flash.h
│ │ ├── flash_bwd_hdim128_bf16_causal_sm80.cu
│ │ ├── flash_bwd_hdim128_bf16_sm80.cu
│ │ ├── flash_bwd_hdim128_fp16_causal_sm80.cu
│ │ ├── flash_bwd_hdim128_fp16_sm80.cu
│ │ ├── flash_bwd_hdim192_bf16_causal_sm80.cu
│ │ ├── flash_bwd_hdim192_bf16_sm80.cu
│ │ ├── flash_bwd_hdim192_fp16_causal_sm80.cu
│ │ ├── flash_bwd_hdim192_fp16_sm80.cu
│ │ ├── flash_bwd_hdim256_bf16_causal_sm80.cu
│ │ ├── flash_bwd_hdim256_bf16_sm80.cu
│ │ ├── flash_bwd_hdim256_fp16_causal_sm80.cu
│ │ ├── flash_bwd_hdim256_fp16_sm80.cu
│ │ ├── flash_bwd_hdim32_bf16_causal_sm80.cu
│ │ ├── flash_bwd_hdim32_bf16_sm80.cu
│ │ ├── flash_bwd_hdim32_fp16_causal_sm80.cu
│ │ ├── flash_bwd_hdim32_fp16_sm80.cu
│ │ ├── flash_bwd_hdim64_bf16_causal_sm80.cu
│ │ ├── flash_bwd_hdim64_bf16_sm80.cu
│ │ ├── flash_bwd_hdim64_fp16_causal_sm80.cu
│ │ ├── flash_bwd_hdim64_fp16_sm80.cu
│ │ ├── flash_bwd_hdim96_bf16_causal_sm80.cu
│ │ ├── flash_bwd_hdim96_bf16_sm80.cu
│ │ ├── flash_bwd_hdim96_fp16_causal_sm80.cu
│ │ ├── flash_bwd_hdim96_fp16_sm80.cu
│ │ ├── flash_bwd_kernel.h
│ │ ├── flash_bwd_launch_template.h
│ │ ├── flash_bwd_preprocess_kernel.h
│ │ ├── flash_fwd_hdim128_bf16_causal_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_causal_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_causal_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_causal_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_causal_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_causal_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_sm80.cu
│ │ ├── flash_fwd_hdim32_bf16_causal_sm80.cu
│ │ ├── flash_fwd_hdim32_bf16_sm80.cu
│ │ ├── flash_fwd_hdim32_fp16_causal_sm80.cu
│ │ ├── flash_fwd_hdim32_fp16_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_causal_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_causal_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_causal_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_causal_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_sm80.cu
│ │ ├── flash_fwd_kernel.h
│ │ ├── flash_fwd_launch_template.h
│ │ ├── flash_fwd_split_hdim128_bf16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim128_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim128_fp16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim128_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim192_bf16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim192_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim192_fp16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim192_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim256_bf16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim256_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim256_fp16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim256_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim32_bf16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim32_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim32_fp16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim32_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim64_bf16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim64_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim64_fp16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim64_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim96_bf16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim96_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim96_fp16_causal_sm80.cu
│ │ ├── flash_fwd_split_hdim96_fp16_sm80.cu
│ │ ├── generate_kernels.py
│ │ ├── hardware_info.h
│ │ ├── kernel_traits.h
│ │ ├── mask.h
│ │ ├── namespace_config.h
│ │ ├── philox.cuh
│ │ ├── philox_unpack.cuh
│ │ ├── rotary.h
│ │ ├── softmax.h
│ │ ├── static_switch.h
│ │ └── utils.h
│ ├── flash_attn_ck/
│ │ ├── flash_api.cpp
│ │ ├── flash_common.cpp
│ │ ├── flash_common.hpp
│ │ ├── mha_bwd.cpp
│ │ ├── mha_fwd.cpp
│ │ ├── mha_fwd_kvcache.cpp
│ │ ├── mha_varlen_bwd.cpp
│ │ └── mha_varlen_fwd.cpp
│ ├── fused_dense_lib/
│ │ ├── README.md
│ │ ├── fused_dense.cpp
│ │ ├── fused_dense_cuda.cu
│ │ └── setup.py
│ └── layer_norm/
│ ├── README.md
│ ├── ln.h
│ ├── ln_api.cpp
│ ├── ln_bwd_1024.cu
│ ├── ln_bwd_1280.cu
│ ├── ln_bwd_1536.cu
│ ├── ln_bwd_2048.cu
│ ├── ln_bwd_256.cu
│ ├── ln_bwd_2560.cu
│ ├── ln_bwd_3072.cu
│ ├── ln_bwd_4096.cu
│ ├── ln_bwd_512.cu
│ ├── ln_bwd_5120.cu
│ ├── ln_bwd_6144.cu
│ ├── ln_bwd_7168.cu
│ ├── ln_bwd_768.cu
│ ├── ln_bwd_8192.cu
│ ├── ln_bwd_kernels.cuh
│ ├── ln_fwd_1024.cu
│ ├── ln_fwd_1280.cu
│ ├── ln_fwd_1536.cu
│ ├── ln_fwd_2048.cu
│ ├── ln_fwd_256.cu
│ ├── ln_fwd_2560.cu
│ ├── ln_fwd_3072.cu
│ ├── ln_fwd_4096.cu
│ ├── ln_fwd_512.cu
│ ├── ln_fwd_5120.cu
│ ├── ln_fwd_6144.cu
│ ├── ln_fwd_7168.cu
│ ├── ln_fwd_768.cu
│ ├── ln_fwd_8192.cu
│ ├── ln_fwd_kernels.cuh
│ ├── ln_kernel_traits.h
│ ├── ln_parallel_bwd_1024.cu
│ ├── ln_parallel_bwd_1280.cu
│ ├── ln_parallel_bwd_1536.cu
│ ├── ln_parallel_bwd_2048.cu
│ ├── ln_parallel_bwd_256.cu
│ ├── ln_parallel_bwd_2560.cu
│ ├── ln_parallel_bwd_3072.cu
│ ├── ln_parallel_bwd_4096.cu
│ ├── ln_parallel_bwd_512.cu
│ ├── ln_parallel_bwd_5120.cu
│ ├── ln_parallel_bwd_6144.cu
│ ├── ln_parallel_bwd_7168.cu
│ ├── ln_parallel_bwd_768.cu
│ ├── ln_parallel_bwd_8192.cu
│ ├── ln_parallel_fwd_1024.cu
│ ├── ln_parallel_fwd_1280.cu
│ ├── ln_parallel_fwd_1536.cu
│ ├── ln_parallel_fwd_2048.cu
│ ├── ln_parallel_fwd_256.cu
│ ├── ln_parallel_fwd_2560.cu
│ ├── ln_parallel_fwd_3072.cu
│ ├── ln_parallel_fwd_4096.cu
│ ├── ln_parallel_fwd_512.cu
│ ├── ln_parallel_fwd_5120.cu
│ ├── ln_parallel_fwd_6144.cu
│ ├── ln_parallel_fwd_7168.cu
│ ├── ln_parallel_fwd_768.cu
│ ├── ln_parallel_fwd_8192.cu
│ ├── ln_parallel_residual_bwd_kernels.cuh
│ ├── ln_parallel_residual_fwd_kernels.cuh
│ ├── ln_utils.cuh
│ ├── setup.py
│ └── static_switch.h
├── examples/
│ └── inference/
│ └── README.md
├── flash_attn/
│ ├── __init__.py
│ ├── bert_padding.py
│ ├── cute/
│ │ ├── .flake8
│ │ ├── AUTHORS
│ │ ├── LICENSE
│ │ ├── MANIFEST.in
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── ampere_helpers.py
│ │ ├── barrier.py
│ │ ├── bench_utils.py
│ │ ├── benchmark.py
│ │ ├── blackwell_helpers.py
│ │ ├── block_info.py
│ │ ├── block_sparse_utils.py
│ │ ├── block_sparsity.py
│ │ ├── cache_utils.py
│ │ ├── compute_block_sparsity.py
│ │ ├── copy_utils.py
│ │ ├── cute_dsl_ptxas.py
│ │ ├── cute_dsl_utils.py
│ │ ├── fa_logging.py
│ │ ├── fast_math.py
│ │ ├── flash_bwd.py
│ │ ├── flash_bwd_postprocess.py
│ │ ├── flash_bwd_preprocess.py
│ │ ├── flash_bwd_sm100.py
│ │ ├── flash_bwd_sm120.py
│ │ ├── flash_bwd_sm90.py
│ │ ├── flash_fwd.py
│ │ ├── flash_fwd_combine.py
│ │ ├── flash_fwd_sm100.py
│ │ ├── flash_fwd_sm120.py
│ │ ├── flash_fwd_sm90.py
│ │ ├── interface.py
│ │ ├── mask.py
│ │ ├── mma_sm100_desc.py
│ │ ├── named_barrier.py
│ │ ├── pack_gqa.py
│ │ ├── paged_kv.py
│ │ ├── pipeline.py
│ │ ├── pyproject.toml
│ │ ├── seqlen_info.py
│ │ ├── sm90_config_search.py
│ │ ├── softmax.py
│ │ ├── testing.py
│ │ ├── tile_scheduler.py
│ │ └── utils.py
│ ├── flash_attn_interface.py
│ ├── flash_attn_triton.py
│ ├── flash_attn_triton_og.py
│ ├── flash_blocksparse_attention.py
│ ├── flash_blocksparse_attn_interface.py
│ ├── layers/
│ │ ├── __init__.py
│ │ ├── patch_embed.py
│ │ └── rotary.py
│ ├── losses/
│ │ ├── __init__.py
│ │ └── cross_entropy.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── baichuan.py
│ │ ├── bert.py
│ │ ├── bigcode.py
│ │ ├── btlm.py
│ │ ├── falcon.py
│ │ ├── gpt.py
│ │ ├── gpt_neox.py
│ │ ├── gptj.py
│ │ ├── llama.py
│ │ ├── opt.py
│ │ └── vit.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── block.py
│ │ ├── embedding.py
│ │ ├── mha.py
│ │ └── mlp.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── fused_dense.py
│ │ ├── layer_norm.py
│ │ ├── rms_norm.py
│ │ └── triton/
│ │ ├── __init__.py
│ │ ├── cross_entropy.py
│ │ ├── k_activations.py
│ │ ├── layer_norm.py
│ │ ├── linear.py
│ │ ├── mlp.py
│ │ └── rotary.py
│ ├── pyproject.toml
│ └── utils/
│ ├── __init__.py
│ ├── benchmark.py
│ ├── distributed.py
│ ├── generation.py
│ ├── library.py
│ ├── pretrained.py
│ ├── testing.py
│ └── torch.py
├── hopper/
│ ├── __init__.py
│ ├── benchmark_attn.py
│ ├── benchmark_flash_attention_fp8.py
│ ├── benchmark_mla_decode.py
│ ├── benchmark_split_kv.py
│ ├── block.h
│ ├── copy_sm90_bulk_reduce.hpp
│ ├── cuda_check.h
│ ├── epilogue_bwd.hpp
│ ├── epilogue_fwd.hpp
│ ├── flash.h
│ ├── flash_api.cpp
│ ├── flash_api_stable.cpp
│ ├── flash_attn_interface.py
│ ├── flash_bwd_kernel_sm80.h
│ ├── flash_bwd_kernel_sm90.h
│ ├── flash_bwd_launch_template.h
│ ├── flash_bwd_postprocess_kernel.h
│ ├── flash_bwd_preprocess_kernel.h
│ ├── flash_fwd_combine.cu
│ ├── flash_fwd_combine_kernel.h
│ ├── flash_fwd_combine_launch_template.h
│ ├── flash_fwd_kernel_sm80.h
│ ├── flash_fwd_kernel_sm90.h
│ ├── flash_fwd_launch_template.h
│ ├── flash_prepare_scheduler.cu
│ ├── generate_kernels.py
│ ├── heuristics.h
│ ├── instantiations/
│ │ ├── flash_bwd_hdim128_bf16_sm80.cu
│ │ ├── flash_bwd_hdim128_bf16_sm90.cu
│ │ ├── flash_bwd_hdim128_bf16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim128_bf16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim128_bf16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim128_fp16_sm80.cu
│ │ ├── flash_bwd_hdim128_fp16_sm90.cu
│ │ ├── flash_bwd_hdim128_fp16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim128_fp16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim128_fp16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim192_bf16_sm80.cu
│ │ ├── flash_bwd_hdim192_bf16_sm90.cu
│ │ ├── flash_bwd_hdim192_bf16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim192_bf16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim192_bf16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim192_fp16_sm80.cu
│ │ ├── flash_bwd_hdim192_fp16_sm90.cu
│ │ ├── flash_bwd_hdim192_fp16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim192_fp16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim192_fp16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim256_bf16_sm80.cu
│ │ ├── flash_bwd_hdim256_bf16_sm90.cu
│ │ ├── flash_bwd_hdim256_bf16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim256_bf16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim256_bf16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim256_fp16_sm80.cu
│ │ ├── flash_bwd_hdim256_fp16_sm90.cu
│ │ ├── flash_bwd_hdim256_fp16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim256_fp16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim256_fp16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim64_bf16_sm80.cu
│ │ ├── flash_bwd_hdim64_bf16_sm90.cu
│ │ ├── flash_bwd_hdim64_bf16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim64_bf16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim64_bf16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim64_fp16_sm80.cu
│ │ ├── flash_bwd_hdim64_fp16_sm90.cu
│ │ ├── flash_bwd_hdim64_fp16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim64_fp16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim64_fp16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim96_bf16_sm80.cu
│ │ ├── flash_bwd_hdim96_bf16_sm90.cu
│ │ ├── flash_bwd_hdim96_bf16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim96_bf16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim96_bf16_softcapall_sm90.cu
│ │ ├── flash_bwd_hdim96_fp16_sm80.cu
│ │ ├── flash_bwd_hdim96_fp16_sm90.cu
│ │ ├── flash_bwd_hdim96_fp16_softcap_sm80.cu
│ │ ├── flash_bwd_hdim96_fp16_softcap_sm90.cu
│ │ ├── flash_bwd_hdim96_fp16_softcapall_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_sm100.cu
│ │ ├── flash_fwd_hdim128_bf16_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_split_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_bf16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim128_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdim128_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_split_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim128_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim128_fp16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_128_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_split_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_bf16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdim192_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_split_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim192_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim192_fp16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_split_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_bf16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdim256_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_split_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim256_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim256_fp16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_256_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_split_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_bf16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdim64_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_split_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim64_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim64_fp16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_split_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_bf16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdim96_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_split_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_softcapall_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_split_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_split_softcap_sm80.cu
│ │ ├── flash_fwd_hdim96_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdim96_fp16_split_softcapall_sm80.cu
│ │ ├── flash_fwd_hdimall_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdimall_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdimall_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_split_sm90.cu
│ │ ├── flash_fwd_hdimall_fp16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_paged_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_split_sm90.cu
│ │ ├── flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_paged_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_split_sm90.cu
│ │ ├── flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_paged_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_paged_split_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_softcap_sm90.cu
│ │ ├── flash_fwd_hdimdiff_fp16_split_sm90.cu
│ │ └── flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu
│ ├── mainloop_bwd_sm80.hpp
│ ├── mainloop_bwd_sm90_tma_gmma_ws.hpp
│ ├── mainloop_fwd_sm80.hpp
│ ├── mainloop_fwd_sm90_tma_gmma_ws.hpp
│ ├── mask.h
│ ├── named_barrier.hpp
│ ├── pack_gqa.h
│ ├── padding.py
│ ├── paged_kv.h
│ ├── rotary.h
│ ├── seqlen.h
│ ├── setup.py
│ ├── sm90_pipeline_no_cluster.hpp
│ ├── softmax.h
│ ├── static_switch.h
│ ├── test_attn_kvcache.py
│ ├── test_flash_attn.py
│ ├── test_flash_attn_bwd_determinism.py
│ ├── test_flash_attn_triton_amd.py
│ ├── test_kvcache.py
│ ├── test_torch_compile_and_export.py
│ ├── test_util.py
│ ├── tile_scheduler.hpp
│ ├── tile_size.h
│ └── utils.h
├── setup.py
├── tests/
│ ├── cute/
│ │ ├── benchmark_block_sparsity.py
│ │ ├── benchmark_mask_mod.py
│ │ ├── conftest.py
│ │ ├── mask_mod_definitions.py
│ │ ├── score_mod_definitions.py
│ │ ├── test_block_sparsity.py
│ │ ├── test_flash_attn.py
│ │ ├── test_flash_attn_combine.py
│ │ ├── test_flash_attn_fast.py
│ │ ├── test_flash_attn_race_condition.py
│ │ ├── test_flash_attn_varlen.py
│ │ ├── test_mask_mod.py
│ │ ├── test_score_mod.py
│ │ ├── test_score_mod_varlen.py
│ │ └── test_utils.py
│ ├── layers/
│ │ └── test_rotary.py
│ ├── losses/
│ │ ├── test_cross_entropy.py
│ │ └── test_cross_entropy_parallel.py
│ ├── models/
│ │ ├── test_baichuan.py
│ │ ├── test_bert.py
│ │ ├── test_bigcode.py
│ │ ├── test_btlm.py
│ │ ├── test_falcon.py
│ │ ├── test_gpt.py
│ │ ├── test_gpt_generation_parallel.py
│ │ ├── test_gpt_neox.py
│ │ ├── test_gpt_parallel.py
│ │ ├── test_gptj.py
│ │ ├── test_llama.py
│ │ ├── test_opt.py
│ │ └── test_vit.py
│ ├── modules/
│ │ ├── test_block_parallel.py
│ │ ├── test_embedding_parallel.py
│ │ ├── test_mha_parallel.py
│ │ └── test_mlp_parallel.py
│ ├── ops/
│ │ ├── test_dropout_layer_norm.py
│ │ ├── test_fused_dense.py
│ │ ├── test_fused_dense_parallel.py
│ │ └── triton/
│ │ └── test_layer_norm.py
│ ├── pyproject.toml
│ ├── test_flash_attn.py
│ ├── test_flash_attn_ck.py
│ ├── test_flash_attn_triton_amd.py
│ ├── test_rotary.py
│ └── test_util.py
├── tools/
│ └── sass_diff.py
├── training/
│ ├── Dockerfile
│ ├── README.md
│ ├── configs/
│ │ ├── callbacks/
│ │ │ ├── causality-monitor.yaml
│ │ │ ├── default.yaml
│ │ │ ├── ema.yaml
│ │ │ ├── flop-count.yaml
│ │ │ ├── gpu-monitor.yaml
│ │ │ ├── model-summary.yaml
│ │ │ ├── none.yaml
│ │ │ ├── norm-monitor.yaml
│ │ │ ├── params-log.yaml
│ │ │ └── wandb.yaml
│ │ ├── config.yaml
│ │ ├── datamodule/
│ │ │ ├── openwebtext.yaml
│ │ │ └── thepile.yaml
│ │ ├── experiment/
│ │ │ ├── owt/
│ │ │ │ ├── base.yaml
│ │ │ │ ├── gpt2l-flash.yaml
│ │ │ │ ├── gpt2l-hf.yaml
│ │ │ │ ├── gpt2l.yaml
│ │ │ │ ├── gpt2m-flash.yaml
│ │ │ │ ├── gpt2m-hf.yaml
│ │ │ │ ├── gpt2m.yaml
│ │ │ │ ├── gpt2s-flash.yaml
│ │ │ │ ├── gpt2s-hf.yaml
│ │ │ │ ├── gpt2s.yaml
│ │ │ │ ├── gpt2xl-flash.yaml
│ │ │ │ ├── gpt2xl-hf.yaml
│ │ │ │ └── gpt2xl.yaml
│ │ │ └── pile/
│ │ │ ├── base.yaml
│ │ │ ├── gpt3-2.7B-flash-8k.yaml
│ │ │ ├── gpt3-2.7B-flash-hdim128-rotary-8k.yaml
│ │ │ ├── gpt3-2.7B-flash-hdim128-rotary.yaml
│ │ │ ├── gpt3-2.7B-flash-hdim128.yaml
│ │ │ ├── gpt3-2.7B-flash-rotary-8k.yaml
│ │ │ ├── gpt3-2.7B-flash-rotary.yaml
│ │ │ ├── gpt3-2.7B-flash.yaml
│ │ │ ├── gpt3-2.7B-hf-hdim128.yaml
│ │ │ ├── gpt3-2.7B-hf.yaml
│ │ │ ├── gpt3l-flash-8k.yaml
│ │ │ ├── gpt3l-flash-rotary-30B.yaml
│ │ │ ├── gpt3l-flash-rotary-8k.yaml
│ │ │ ├── gpt3l-flash-rotary.yaml
│ │ │ ├── gpt3l-flash.yaml
│ │ │ ├── gpt3l-hf.yaml
│ │ │ ├── gpt3m-flash-8k.yaml
│ │ │ ├── gpt3m-flash-rotary-30B.yaml
│ │ │ ├── gpt3m-flash-rotary-8k.yaml
│ │ │ ├── gpt3m-flash-rotary.yaml
│ │ │ ├── gpt3m-flash.yaml
│ │ │ ├── gpt3m-hf.yaml
│ │ │ ├── gpt3s-flash-8k.yaml
│ │ │ ├── gpt3s-flash-rotary-30B.yaml
│ │ │ ├── gpt3s-flash-rotary-8k.yaml
│ │ │ ├── gpt3s-flash-rotary.yaml
│ │ │ ├── gpt3s-flash.yaml
│ │ │ ├── gpt3s-hf.yaml
│ │ │ ├── gpt3xl-flash-8k.yaml
│ │ │ ├── gpt3xl-flash-rotary-60B.yaml
│ │ │ ├── gpt3xl-flash-rotary-8k.yaml
│ │ │ ├── gpt3xl-flash-rotary.yaml
│ │ │ ├── gpt3xl-flash.yaml
│ │ │ └── gpt3xl-hf.yaml
│ │ ├── logger/
│ │ │ ├── comet.yaml
│ │ │ ├── csv.yaml
│ │ │ ├── many_loggers.yaml
│ │ │ ├── mlflow.yaml
│ │ │ ├── neptune.yaml
│ │ │ ├── tensorboard.yaml
│ │ │ └── wandb.yaml
│ │ ├── metrics/
│ │ │ ├── acc.yaml
│ │ │ ├── acc_ignore_index.yaml
│ │ │ ├── acctop5.yaml
│ │ │ ├── mse.yaml
│ │ │ ├── num-tokens.yaml
│ │ │ └── perplexity.yaml
│ │ ├── mode/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ ├── exp.yaml
│ │ │ ├── profile.yaml
│ │ │ └── smoke.yaml
│ │ ├── model/
│ │ │ ├── gpt2-hf.yaml
│ │ │ ├── gpt2.yaml
│ │ │ └── gpt2model/
│ │ │ ├── gpt2-large.yaml
│ │ │ ├── gpt2-medium.yaml
│ │ │ ├── gpt2-small.yaml
│ │ │ └── gpt2-xlarge.yaml
│ │ ├── optimizer/
│ │ │ ├── adam.yaml
│ │ │ ├── adamw-apex-distributed.yaml
│ │ │ ├── adamw-apex-zero.yaml
│ │ │ ├── adamw-apex.yaml
│ │ │ ├── adamw-zero.yaml
│ │ │ ├── adamw.yaml
│ │ │ ├── fusedlamb-ds.yaml
│ │ │ ├── fusedlamb.yaml
│ │ │ └── sgd.yaml
│ │ ├── scheduler/
│ │ │ ├── cosine-warmup-timm.yaml
│ │ │ ├── cosine-warmup.yaml
│ │ │ ├── invsqrt.yaml
│ │ │ ├── linear-warmup.yaml
│ │ │ ├── multi-step.yaml
│ │ │ ├── plateau.yaml
│ │ │ ├── poly-warmup.yaml
│ │ │ └── step.yaml
│ │ ├── task/
│ │ │ └── sequence-model.yaml
│ │ └── trainer/
│ │ ├── all_params.yaml
│ │ ├── ddp.yaml
│ │ ├── debug.yaml
│ │ └── default.yaml
│ ├── run.py
│ ├── src/
│ │ ├── callbacks/
│ │ │ ├── __init__.py
│ │ │ ├── causality_monitor.py
│ │ │ ├── ema.py
│ │ │ ├── flop_count.py
│ │ │ ├── gpu_affinity.py
│ │ │ ├── loss_scale_monitor.py
│ │ │ ├── model_checkpoint.py
│ │ │ ├── norm_monitor.py
│ │ │ ├── params_log.py
│ │ │ ├── speed_monitor.py
│ │ │ └── wandb_callbacks.py
│ │ ├── datamodules/
│ │ │ ├── datasets/
│ │ │ │ ├── detokenizer.py
│ │ │ │ └── lm_dataset.py
│ │ │ ├── fault_tolerant_sampler.py
│ │ │ ├── imagenet.py
│ │ │ ├── language_modeling_hf.py
│ │ │ └── timm_mixup.py
│ │ ├── distributed/
│ │ │ └── ddp_comm_hooks.py
│ │ ├── eval.py
│ │ ├── metrics/
│ │ │ ├── accuracy.py
│ │ │ ├── num_tokens.py
│ │ │ └── perplexity.py
│ │ ├── models/
│ │ │ └── modules/
│ │ │ └── seq_common.py
│ │ ├── optim/
│ │ │ ├── param_grouping.py
│ │ │ └── timm_lr_scheduler.py
│ │ ├── tasks/
│ │ │ └── seq.py
│ │ ├── train.py
│ │ └── utils/
│ │ ├── checkpoint.py
│ │ ├── ddp_zero1.py
│ │ ├── ddp_zero2.py
│ │ ├── distributed.py
│ │ ├── ema.py
│ │ ├── flops.py
│ │ ├── gpu_affinity.py
│ │ └── utils.py
│ └── tests/
│ └── datamodules/
│ └── test_language_modeling_hf.py
└── usage.md
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/README.md
================================================
# GitHub Workflow Tagging Flow
This repository uses separate tag lanes so FA2 and FA4 publishing do not collide.
## Release lanes
| Tag pattern | Workflow | Package target | Version source |
| --- | --- | --- | --- |
| `v*` | `.github/workflows/publish.yml` | Root package (`flash-attn`) | Root package version metadata |
| `fa4-v*` | `.github/workflows/publish-fa4.yml` | `flash_attn/cute` package (`flash-attn-4`) | `setuptools-scm` with `fa4-v*` tags |
## How to publish
### FA2 / root package lane
1. Create a tag matching `v*` (example: `v2.9.0`).
2. Push that tag.
3. `publish.yml` creates a release, builds wheel matrix artifacts, and publishes to PyPI.
### FA4 / CUTE package lane
1. Create a tag matching `fa4-v*` (example: `fa4-v0.1.0`).
2. Push that tag.
3. `publish-fa4.yml` builds from `flash_attn/cute`, creates a GitHub release, and uploads `flash-attn-4` to PyPI.
## Guardrails
- Do not use `v*` tags for FA4 releases.
- Do not use `fa4-v*` tags for FA2 releases.
- Keep `flash_attn/cute/pyproject.toml` tag parsing in sync with the FA4 tag prefix.
- The workflow filename (`publish-fa4.yml`) is part of the PyPI trusted publishing OIDC identity — do not rename without updating PyPI.
================================================
FILE: .github/workflows/_build.yml
================================================
name: ~Build wheel template
on:
workflow_call:
inputs:
runs-on:
description: "The runner to use for the build"
required: true
type: string
python-version:
description: "The Python version to use for the build"
required: true
type: string
cuda-version:
description: "The CUDA version to use for the build"
required: true
type: string
torch-version:
description: "The PyTorch version to use for the build"
required: true
type: string
cxx11_abi:
description: "The C++11 ABI to use for the build"
required: true
type: string
upload-to-release:
description: "Upload wheel to this release"
required: false
type: boolean
default: false
release-version:
description: "Upload wheel to this release"
required: false
type: string
defaults:
run:
shell: bash -x -e -u -o pipefail {0}
jobs:
build-wheel:
runs-on: ${{ inputs.runs-on }}
name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }})
steps:
- name: Checkout
uses: actions/checkout@v5
with:
ref: ${{ inputs.release-version }}
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ inputs.cuda-version }}
if: ${{ inputs.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.30
id: cuda-toolkit
with:
cuda: ${{ inputs.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }}
method: "network"
sub-packages: '["nvcc"]'
- name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }}
run: |
pip install --upgrade pip
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
pip install typing-extensions==4.12.2
# Pick the highest available PyTorch wheel CUDA version that doesn't exceed system CUDA
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
available = { \
'2.6': [118, 124, 126], \
'2.7': [118, 126, 128], \
'2.8': [126, 128, 129], \
'2.9': [126, 128, 130], \
'2.10': [126, 128, 130], \
}[env['MATRIX_TORCH_VERSION']]; \
sys_cuda = int(env['MATRIX_CUDA_VERSION']); \
print(max(v for v in available if v <= sys_cuda))" \
)
# detect if we're on ARM
if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
PLAT=linux_aarch64
else
PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64
fi
echo "PLAT=$PLAT" >> $GITHUB_ENV
if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then
# pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
# Can't use --no-deps because we need cudnn etc.
# Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904
pip install jinja2
TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl
TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl
pip install --no-cache-dir --pre "${TRITON_URL}"
pip install --no-cache-dir --pre "${TORCH_URL}"
else
pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
- name: Restore build cache
uses: actions/cache/restore@v4
with:
path: build.tar
key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
restore-keys: |
build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-
- name: Unpack build cache
run: |
echo ::group::Adjust timestamps
sudo find / -exec touch -t 197001010000 {} + || true
echo ::endgroup::
if [ -f build.tar ]; then
find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} +
tar -xpvf build.tar -C .
else
echo "No build.tar found, skipping"
fi
ls -al ./
ls -al build/ || true
ls -al csrc/ || true
- name: Build wheel
id: build_wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==75.8.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
# nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] || [ "$MATRIX_CUDA_VERSION" == "130" ] && echo 1 || echo 2)
export NVCC_THREADS=2
export FLASH_ATTENTION_FORCE_BUILD="TRUE"
export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}
# 5h timeout since GH allows max 6h and we want some buffer
EXIT_CODE=0
timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?
if [ $EXIT_CODE -eq 0 ]; then
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
fi
# Store the exit code as a step output (GITHUB_OUTPUT) so later steps can check it
echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT"
# Propagate the build's exit code; if timeout killed the build (exit code 124), the always() steps below still archive and cache the partial build
exit $EXIT_CODE
- name: Log build logs after timeout
if: always() && steps.build_wheel.outputs.build_exit_code == 124
run: |
ls -al ./
tar -cvf build.tar . --atime-preserve=replace
- name: Save build cache timeout
if: always() && steps.build_wheel.outputs.build_exit_code == 124
uses: actions/cache/save@v4
with:
key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
path: build.tar
- name: Log Built Wheels
run: |
ls dist
- name: Get Release with tag
id: get_current_release
uses: joutvhu/get-release@v1
with:
tag_name: ${{ inputs.release-version }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
if: inputs.upload-to-release
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
================================================
FILE: .github/workflows/build.yml
================================================
name: Build wheels
on:
workflow_dispatch:
inputs:
runs-on:
description: "The runner to use for the build"
required: true
type: string
default: ubuntu-22.04
python-version:
description: "The Python version to use for the build"
required: true
type: string
cuda-version:
description: "The CUDA version to use for the build"
required: true
type: string
torch-version:
description: "The PyTorch version to use for the build"
required: true
type: string
cxx11_abi:
description: "Enable torch flag C++11 ABI (TRUE/FALSE)"
required: true
type: string
upload-to-release:
description: "Upload wheel to this release"
required: false
type: boolean
default: false
release-version:
description: "Upload wheel to this release"
required: false
type: string
jobs:
build-wheels:
uses: ./.github/workflows/_build.yml
with:
runs-on: ${{ inputs.runs-on }}
python-version: ${{ inputs.python-version }}
cuda-version: ${{ inputs.cuda-version }}
torch-version: ${{ inputs.torch-version }}
cxx11_abi: ${{ inputs.cxx11_abi }}
upload-to-release: ${{ inputs.upload-to-release }}
release-version: ${{ inputs.release-version }}
================================================
FILE: .github/workflows/pre-commit.yaml
================================================
name: Lint
on:
pull_request:
paths:
- 'flash_attn/cute/flash_bwd_sm90.py'
- 'flash_attn/cute/flash_bwd_preprocess.py'
- 'flash_attn/cute/flash_bwd_postprocess.py'
- 'flash_attn/cute/softmax.py'
- '.pre-commit-config.yaml'
push:
branches:
- main
paths:
- 'flash_attn/cute/flash_bwd_sm90.py'
- 'flash_attn/cute/flash_bwd_preprocess.py'
- 'flash_attn/cute/flash_bwd_postprocess.py'
- 'flash_attn/cute/softmax.py'
- '.pre-commit-config.yaml'
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.11'
- name: Run pre-commit
uses: pre-commit/action@v3.0.1
================================================
FILE: .github/workflows/publish-fa4.yml
================================================
name: Publish flash-attn-4 to PyPI
on:
push:
tags:
- 'fa4-v*'
permissions:
contents: write
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install build dependencies
run: pip install build twine
- name: Build package
run: python -m build
working-directory: flash_attn/cute
- name: Check package metadata
run: twine check dist/*
working-directory: flash_attn/cute
- name: Store distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: flash_attn/cute/dist/
github-release:
needs: build
runs-on: ubuntu-latest
steps:
- name: Download distribution packages
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Create GitHub Release
uses: softprops/action-gh-release@v2
with:
files: dist/*
generate_release_notes: true
publish-to-pypi:
needs: build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/flash-attn-4
permissions:
id-token: write
steps:
- name: Download distribution packages
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
================================================
FILE: .github/workflows/publish.yml
================================================
# This workflow will:
# - Create a new GitHub release
# - Build wheels for supported architectures
# - Deploy the wheels to the GitHub release
# - Release the static code to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
on:
create:
tags:
- v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
outputs:
release-version: ${{ steps.extract_branch.outputs.branch }}
steps:
- name: Get the tag version
id: extract_branch
run: echo "branch=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
shell: bash
- name: Create Release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: gh release create ${{ steps.extract_branch.outputs.branch }} --repo $GITHUB_REPOSITORY --title ${{ steps.extract_branch.outputs.branch }} --generate-notes
shell: bash
build_wheels:
name: Build Wheel
needs: setup_release
strategy:
fail-fast: false
matrix:
# Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-22.04, ubuntu-22.04-arm]
python-version: ["3.10", "3.11", "3.12", "3.13"]
torch-version: ["2.6.0", "2.7.1", "2.8.0", "2.9.1", "2.10.0"]
cuda-version: ["12.9.1", "13.0.1"]
# We need separate wheels that either use the C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# PyTorch wheels currently don't use it, but nvcr images have PyTorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ["FALSE", "TRUE"]
exclude:
# CUDA 13.0 is only supported by PyTorch 2.9+
- torch-version: "2.6.0"
cuda-version: "13.0.1"
- torch-version: "2.7.1"
cuda-version: "13.0.1"
- torch-version: "2.8.0"
cuda-version: "13.0.1"
# No aarch64 PyTorch wheels for 2.6.0
- torch-version: "2.6.0"
os: ubuntu-22.04-arm
# PyTorch 2.7+ pip wheels use CXX11_ABI=1 by default, no need for FALSE
- torch-version: "2.7.1"
cxx11_abi: "FALSE"
- torch-version: "2.8.0"
cxx11_abi: "FALSE"
- torch-version: "2.9.1"
cxx11_abi: "FALSE"
- torch-version: "2.10.0"
cxx11_abi: "FALSE"
uses: ./.github/workflows/_build.yml
with:
runs-on: ${{ matrix.os }}
python-version: ${{ matrix.python-version }}
cuda-version: ${{ matrix.cuda-version }}
torch-version: ${{ matrix.torch-version }}
cxx11_abi: ${{ matrix.cxx11_abi }}
release-version: ${{ needs.setup_release.outputs.release-version }}
upload-to-release: true
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip install ninja packaging wheel twine
# Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)
pip install setuptools==75.8.0
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
================================================
FILE: .gitignore
================================================
*.ncu-rep
*.sass
*.ptx
*.cubin
*.plk
.DS_store
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
.eggs/
# IDE-related
.idea/
.vscode/
# Dev
venv
# compile-time generated file
flash_attn_config.py
================================================
FILE: .gitmodules
================================================
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
path = csrc/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
branch = amd-master
[submodule "third_party/aiter"]
path = third_party/aiter
url = https://github.com/ROCm/aiter.git
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.13
hooks:
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix]
files: ^flash_attn/cute/.*\.py$
exclude: &cute_exclude |
(?x)^flash_attn/cute/(
flash_bwd|
flash_fwd|
flash_fwd_sm100|
interface|
)\.py$
- id: ruff-format
files: ^flash_attn/cute/.*\.py$
exclude: *cute_exclude
================================================
FILE: AI/DEBUG_2CTA.md
================================================
# Debugging GPU Kernel Hangs (Deadlocks) in CUTLASS DSL / 2CTA Kernels
## General Approach to Debugging Kernel Hangs
### Step 1: Build a minimal repro
Strip the test case down to the smallest input that triggers the hang:
- batch=1, nheads=1, smallest seqlen that hangs
- Single config, no loops, no benchmarking
- Add a timeout or run with `compute-sanitizer` so you can distinguish a hang from slow execution
### Step 2: Add printf to locate the hang
GPU `printf` (`cute.printf`) is the primary tool. The goal is binary search: narrow down which warp and which operation is blocked.
**Printf guards** — avoid print storms:
```python
# One thread per warp:
if cute.arch.thread_idx()[0] % 32 == 0:
cute.printf("...")
# One thread per CTA (elect_one is a context manager, not a bool):
with cute.arch.elect_one():
cute.printf("...")
# One specific thread:
if tidx == 0:
cute.printf("...")
```
**Strategy — coarse to fine:**
1. First, print at the entry/exit of each warp's main function (load, mma, softmax, correction). This tells you which warp is stuck.
2. Then add prints before/after each pipeline wait (`consumer_wait`, `producer_acquire`). This tells you which barrier is stuck.
3. Then print the barrier index, phase, and stage to understand the pipeline state.
**What to print:**
- CTA index (`cute.arch.block_idx()[0]`) — critical for multi-CTA debugging
- Pipeline stage index and phase
- Loop iteration count
- Whether a `try_wait` succeeds or fails (use `try_wait_token` parameter)
### Step 3: Identify the deadlock chain
A hang is always a cycle. Typical chain in a pipelined kernel:
```
MMA waiting for K from load (pipeline_kv full barrier)
-> Load finished but stuck in producer_tail (waiting for MMA to release empty barrier)
-> MMA can't release because it's waiting for K
```
Once you see which barrier is stuck, trace backwards: who is supposed to signal it, and why haven't they?
### Step 4: Vary the problem size systematically
Test with different sequence lengths / block counts to find the pattern:
| seqlen | n_blocks | Result |
|--------|----------|--------|
| 128 | 1 | ? |
| 256 | 2 | ? |
| 384 | 3 | ? |
| 512 | 4 | ? |
If the hang correlates with the number of visits to a pipeline stage (e.g., works for n_blocks <= kv_stages but fails when stages wrap around), the problem is likely in barrier tx_count or phase tracking.
### Step 5: Check barrier byte counts (tx_count)
For TMA-based pipelines, `arrive_and_expect_tx` sets the expected transaction byte count on an mbarrier. If the expected count doesn't match the actual bytes arriving, the barrier either:
- Fires too early (expected < actual) — causes data races
- Never fires (expected > actual) — causes hangs
In **2CTA / cluster mode**, both CTAs' TMAs signal the **same** cluster-level mbarrier. If each CTA's TMA contributes N bytes, the barrier receives 2N bytes total. The tx_count must be `N * cta_group_size`, not just `N`.
**All TMA pipelines need doubling** — Q, K, and V. Even though each CTA loads a different M-tile for Q, both CTAs' TMA operations still signal the same cluster-level barrier, so the expected byte count must account for both.
### Step 6: Check phase / parity tracking
`mbarrier_try_wait_parity` uses a single parity bit (0 or 1). If your pipeline state tracks phase as a monotonically increasing counter (0, 1, 2, 3, ...), you need `phase % 2` before passing it to the barrier wait. Without this, phase=2 looks like phase=0 to the hardware, which can cause waits on already-completed barriers or misses on pending ones.
### Step 7: Beware compiler-as-bug-source
If the kernel works WITH printf but hangs WITHOUT it, the printf is acting as a **compiler barrier**. The MLIR/LLVM backend cannot optimize through an opaque function call like printf, which prevents harmful instruction reordering.
Signs this is happening:
- A single `cute.printf("\n")` in the right function fixes the hang
- PTX fences (`fence_view_async_shared`, `fence_acq_rel_cluster`, `sync_warp`, `fence_proxy`) do NOT fix it — these affect hardware memory ordering, not compiler scheduling
- The fix is location-sensitive (printf in one function fixes it, in another doesn't)
Possible workarounds:
- `@dsl_user_op` decorator on pipeline methods to make them opaque to the compiler
- `asm volatile` barriers (if available in the DSL)
- Compare generated PTX/SASS with and without printf to identify what the compiler is reordering
- File a bug against the CUTLASS DSL / MLIR pipeline
---
## 2CTA-Specific Pitfalls
### tcgen05.commit with empty commit groups
`tcgen05.commit(mbar, mask, cta_group::2)` is supposed to signal an mbarrier after all pending MMA operations complete. But if there are **no pending operations** (empty commit group), the signal only reaches the local CTA's barrier, not the remote CTA's. Fix: issue an explicit `mbarrier_arrive(barrier, dst_cta_rank)` for each of the two CTAs.
### producer_tail deadlock
The default `producer_tail` (inherited from sm90 pipelines) drains the pipeline by calling `producer_acquire` in a loop. In 2CTA mode this deadlocks because the consumer (MMA warp) may have already exited without releasing all stages. Fix: make `producer_tail` a no-op for 2CTA.
### Tile scheduler must account for cluster shape
Both CTAs in a cluster must get the **same** tile coordinate. Raw `blockIdx.x` assigns consecutive values to CTAs in the same cluster. Fix: divide `blockIdx.x` by `cluster_shape_m`.
### Cross-CTA vs per-CTA pipelines
Pipelines where CTA 1's threads remotely arrive on CTA 0's barriers need cluster-sized cooperative group counts. Pipelines that are purely local to each CTA keep per-CTA counts.
### Softmax masking offset
Causal mask row positions must account for the CTA's position within the cluster. Multiply `m_block` by `cta_group_size` when computing mask coordinates.
================================================
FILE: AI/RACECHECK_TMA_HAZARD.md
================================================
# compute-sanitizer racecheck hazard with `cp.async.bulk`
## Summary
`compute-sanitizer --tool=racecheck` reports false-positive shared-memory race
hazards when `cp.async.bulk` (raw-address TMA) is used in a cross-warp
producer/consumer pipeline inside a dynamic loop. The same pattern with
`cp.async.bulk.tensor` (descriptor-based TMA) reports **zero hazards**.
The fix for the flash backward kernel is to switch the LSE/dPsum copies from
`CopyBulkG2SOp` (`cp.async.bulk`) to `CopyBulkTensorTileG2SOp`
(`cp.async.bulk.tensor`) using `cpasync.make_tiled_tma_atom`.
## Affected code
`flash_attn/cute/flash_bwd_sm100.py` — the SM100 backward attention kernel.
Only **LSE** and **dPsum** buffers are affected because they are the only
TMA-loaded buffers consumed by thread-level shared memory reads (`lds`).
Q/K/V/dO are consumed by UMMA hardware instructions, which do not generate
thread-level `lds` and therefore never trigger racecheck.
## Root cause
racecheck instruments every shared memory access and checks for conflicting
accesses lacking a recognized happens-before relationship.
**`cp.async.bulk` (raw address):** the sanitizer attributes the smem write to
the issuing thread (thread 0 of warp 0 via `elect_one`). When warp 1 issues
`ld.shared.b32` from the same addresses, the sanitizer searches for a
happens-before edge. The only sync is `mbarrier.try_wait.parity` on warp 1
paired with `mbarrier::complete_tx::bytes` completion from the hardware. The
sanitizer does not model this as happens-before across warps in a dynamic loop.
**`cp.async.bulk.tensor` (TMA descriptor):** the TMA engine is a separate
hardware unit. The sanitizer does not attribute the smem write to any thread.
No writer thread means no hazard pair, so no race is reported.
### Instruction comparison
| Variant | PTX | racecheck |
|---------|-----|-----------|
| Raw (cta scope) | `cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes` | **hazard** |
| Raw (cluster scope) | `cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes` | **hazard** |
| Descriptor 1D | `cp.async.bulk.tensor.1d.shared::cta.global.tile.mbarrier::complete_tx::bytes` | clean |
| Descriptor 2D | `cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes` | clean |
### `--racecheck-memcpy-async=no` does not help
This flag controls the older `cp.async` (sm80) instruction family, not
`cp.async.bulk`. The hazard persists with `--racecheck-memcpy-async=no`.
## Proof that it is a false positive
1. **Data correctness** — all variants produce bit-identical results.
2. **Single-warp test** — one warp does both TMA write and thread read in the
same loop; racecheck reports zero hazards with the same mbarrier sync.
3. **Unrolled loop** — fully unrolling (`unroll_full=True`) reports zero
hazards; racecheck tracks mbarrier within straight-line code but not across
a dynamic branch back-edge between warps.
4. **Named barrier** — adding `bar.sync` per iteration between producer and
consumer warps eliminates the hazard; the sync is correct, racecheck just
needs a primitive it recognizes.
5. **Descriptor TMA** — switching to `cp.async.bulk.tensor` with identical
pipeline code eliminates the hazard; the mbarrier protocol is correct.
## Minimal reproducers
### `AI/` (preferred, cleaner)
| File | Copy instruction | Result |
|------|-----------------|--------|
| `racecheck_repro_1d_bulk.py` | `cp.async.bulk` (raw address) | **1 error** |
| `racecheck_repro_1d_tensor.py` | `cp.async.bulk.tensor.1d` (TMA descriptor) | **0 hazards** |
Both are ~75-line self-contained kernels: 2 warps, 4 blocks, 2-stage double
buffering with `PipelineTmaAsync`. Identical pipeline protocol — only the copy
instruction differs.
```bash
python AI/racecheck_repro_1d_bulk.py # correctness
CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_bulk.py # 1 error
compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_tensor.py # 0 hazards
```
### `benchmarks/` (earlier, more variants)
| File | What it tests | Result |
|------|--------------|--------|
| `racecheck_false_positive_repro.py` | `cp.async.bulk` + mbarrier in cross-warp loop | 1 error |
| `racecheck_1d_raw_ptx.py` | Inline PTX `cp.async.bulk.shared::cta.global` | 1 error |
| `racecheck_tma2d_repro.py` | `cp.async.bulk.tensor.2d` via `make_tiled_tma_atom` | 0 hazards |
| `racecheck_tma1d_descriptor.py` | `cp.async.bulk.tensor.1d` via `make_tiled_tma_atom` | 0 hazards |
## PTX-level analysis
Dumped PTX for both `AI/` reproducers (`CUTE_DSL_KEEP_PTX=1`). The generated
code is byte-for-byte identical except for the single copy instruction:
```
# racecheck_repro_1d_bulk.py (HAZARD)
cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes
[%r42], [%rd12], %r43, [%r6+-16];
# racecheck_repro_1d_tensor.py (CLEAN)
cp.async.bulk.tensor.1d.shared::cta.global.tile.mbarrier::complete_tx::bytes.L2::cache_hint
[%r43], [%rd1, {%r71}], [%r6+-16], %rd8;
```
All mbarrier operations (init, `fence.mbarrier_init.release.cluster`,
`arrive.expect_tx`, `try_wait.parity`, `arrive.release`,
`fence.proxy.async.shared::cta`, `bar.warp.sync`) are identical.
### racecheck error output
```
Error: Race reported between Write access at ...+0x430 in racecheck_repro_1d_bulk.py:46
and Read access at ...+0x770 in racecheck_repro_1d_bulk.py:55 [248 hazards]
and Read access at ...+0x7a0 in racecheck_repro_1d_bulk.py:55 [248 hazards]
and Read access at ...+0x7d0 in racecheck_repro_1d_bulk.py:55 [248 hazards]
and Read access at ...+0x800 in racecheck_repro_1d_bulk.py:55 [248 hazards]
```
- **Write** (0x430) = line 46: `cute.copy(atom, src, s, mbar_ptr=...)` — the
`cp.async.bulk` instruction
- **Read** (0x770–0x800) = line 55: `dst[...] = s[...]` — four `ld.shared.b32`
in the consumer warp
## Fix
Change `copy_stats` in the load function from:
```python
copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)
copy_stats = partial(cute.copy, copy_atom_stats)
```
to a descriptor-based TMA using `cpasync.make_tiled_tma_atom` with
`CopyBulkTensorTileG2SOp`. This generates `cp.async.bulk.tensor.1d` instead of
`cp.async.bulk`. racecheck does not attribute the smem write of the
descriptor-based copy to any thread (see "Root cause"), so no hazard is reported.
The pipeline protocol (mbarrier init, arrive_expect_tx, try_wait_parity,
consumer_release) remains identical.
## Backup
`flash_attn/cute/flash_bwd_sm100_gmem_fix.py` contains a working but slower
fix where compute warps read LSE/dPsum directly from global memory, bypassing
the TMA smem pipeline entirely.
## Investigation timeline
1. Observed 2 racecheck errors on LSE and dPsum in `flash_bwd_sm100.py`.
Q/K/V/dO clean.
2. Noticed Q/K/V/dO use UMMA consumers (no thread `lds`) while LSE/dPsum use
thread-level `autovec_copy` from smem — explains why only LSE/dPsum trigger.
3. Built minimal 2-warp pipeline kernel reproducing the hazard.
4. Single-warp version clean — same mbarrier, same addresses.
5. Fully-unrolled version clean — racecheck tracks mbarrier within
straight-line code.
6. `bar.sync` per iteration fixes it — racecheck needs a sync it recognizes
across the loop back-edge.
7. `cp.async.bulk.tensor.2d` clean — different instruction, same pipeline.
8. `cp.async.bulk.tensor.1d` clean — issue is raw vs descriptor, not
dimensionality.
9. Raw inline PTX `cp.async.bulk.shared::cta.global` also triggers — not a
CuTe DSL abstraction issue.
10. Dumped PTX for both `AI/` reproducers — confirmed byte-identical code
except for the copy instruction. Sanitizer attributes smem write to
issuing thread for `cp.async.bulk` but not for `cp.async.bulk.tensor`.
11. Confirmed `--racecheck-memcpy-async=no` does not suppress the hazard —
flag targets older `cp.async`, not `cp.async.bulk`.
================================================
FILE: AI/SM90_BLOCK_SIZE_TUNING.md
================================================
# SM90 Block Size Tuning Guide
How to choose tile sizes and MMA configurations for FlashAttention on Hopper (SM90).
## Tool
Use `flash_attn/cute/sm90_config_search.py` to enumerate feasible configs:
```bash
# Both fwd and bwd
python flash_attn/cute/sm90_config_search.py --headdim 128
# Forward only
python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128
# Backward only, custom tile choices
python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-m 64,80 --tile-n 64,96
```
## Hardware Constraints (H100)
- **SMEM**: 228 KB per SM, of which a single thread block can use at most 227 KB. We reserve ~3 KB for LSE, dPsum, and mbarriers, leaving **224 KB** for tensor buffers.
- **Registers**: Controlled via `setmaxnreg`. Budget per MMA warp group:
- 2 WG: 240 regs/thread, minus 24 overhead = **216 usable**
- 3 WG: 160 regs/thread, minus 32 overhead = **128 usable**
- **GMMA atom**: Always M=64. The effective M dimension (after swap) must be divisible by 64. N dimension must be divisible by `atom_layout_n * 8`.
## Architecture: Warp Groups
Each SM90 backward kernel has `num_wg + 1` warp groups (128 threads each):
- **WG0, warp 0** (producer): TMA loads for Q, K, V, dO, LSE, dPsum
- **WG0, warp 1** (producer): dQaccum store (TMA reduce-add to gmem)
- **WG1..WG(num_wg)** (MMA consumers): All GEMMs
For forward: `num_wg` MMA WGs + 1 producer WG. `tile_m = num_wg * 64` (no swap).
## Key Decisions
### 1. Number of Warp Groups (num_wg)
| num_wg | tile_m (fwd) | Threads | Reg budget | Best for |
|--------|-------------|---------|------------|----------|
| 2 | 128 | 384 | 216/thread | hdim <= 128 |
| 3 | 192 | 512 | 128/thread | hdim 129-192 |
More WGs = larger tile_m = better M-direction parallelism, but tighter register budget and higher smem usage.
### 2. swap_AB
Each MMA can optionally swap its A and B operands. This transposes the output tile, exchanging which dimension maps to M (must be divisible by 64) and which maps to N.
**When to swap:**
- If the natural M dimension isn't divisible by 64 but N is (e.g., tile_m=80 for SdP)
- To change which operand is in registers vs shared memory
**Forward**: No swap needed since tile_m = num_wg * 64 is always divisible by 64.
**Backward** (5 MMAs):
- **SdP** (S=Q@K^T, dP=dO@V^T): output (tile_m, tile_n). Swap if tile_m % 64 != 0.
- **dKV** (dK=dS^T@Q, dV=P^T@dO): output (tile_n, hdim/hdimv). Swap if tile_n % 64 != 0 but hdim % 64 == 0.
- **dQ** (dQ=dS@K): output (tile_m, hdim). Swap if tile_m % 64 != 0 but hdim % 64 == 0.
### 3. AtomLayout
The `atom_layout` distributes WGs across the M and N dimensions of an MMA output. With `num_wg` MMA WGs and `atom_layout_m = A`:
- M direction: A warp groups, each handling M/A rows
- N direction: num_wg/A warp groups, each handling N/(num_wg/A) columns
After swap, the atom layout is also swapped.
**Impact on smem traffic**: More WGs in the N direction (`wg_n` larger) means each instruction reads a smaller B slice, but more instructions total read overlapping A slices. Fewer WGs in N (`wg_n` smaller) means fewer instructions but each reads a larger B slice. Typically **smaller wg_n = less total smem traffic**.
### 4. mma_dkv_is_rs (Register-Source for dKV)
When `AtomLayoutMSdP == 1 && AtomLayoutNdKV == num_wg && SdP_swapAB && !dKV_swapAB`, the P and dS matrices can be kept in registers and fed directly as the A operand of dV and dK GEMMs. This:
- **Eliminates sP from smem** (saves tile_m * tile_n * 2 bytes)
- **Eliminates P R2S store** from smem traffic
- **Eliminates A operand reads** for dK and dV GEMMs
This is a significant optimization — always preferred when the conditions are met.
### 5. Pipeline Staging
**Forward**:
- Q: 1 stage (loaded once per m_block tile and reused across all n_block iterations)
- K, V: 2 stages (double-buffered, pipelined with TMA)
- O: overlaps with Q in smem (reuses same buffer at epilogue)
**Backward**:
- Q: always 2 stages (double-buffered)
- dO: 2 stages if smem allows (matches Q pipeline), else 1 stage
- PdS: 1 stage
- K, V: persistent in smem (loaded once per n_block)
## Register Accounting
Accumulator registers per thread per WG = `M * N / (num_wg * 128)`, where M x N is the output tile.
**Forward peak registers**:
- With WG overlap: `regs_S + regs_P + regs_O` (S in f32, P in bf16, and O are all live at the same time)
- Without overlap: `regs_S + regs_O` (S and O alternate, P reuses S regs)
Where `regs_P = regs_S / 2` (bf16 vs f32).
**Backward peak registers**:
- `max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV`
- S and dP accumulators are both live (S needed for softmax while dP computes)
- dQ reuses S+dP register space after they're consumed
- dK and dV accumulate across m_block iterations
## SMEM Accounting
Sum of tensor buffers (ignoring alignment padding, which is small):
**Forward**: `max(sQ, sO) + sK + sV + sP` (the sK and sV sizes listed below already include both pipeline stages)
- sQ = tile_m * hdim * 2
- sK = tile_n * hdim * 2 * 2 stages
- sV = tile_n * hdimv * 2 * 2 stages
- sO = tile_m * hdimv * 2 (overlaps with sQ)
- sP = tile_m * tile_n * 2 (0 if RS)
**Backward**: `sQ + sK + sV + sdO + sP + sdS + sdQaccum` (the sQ and sdO sizes listed below already include their stage counts)
- sQ = tile_m * hdim * 2 * 2 stages
- sK = tile_n * hdim * 2
- sV = tile_n * hdimv * 2
- sdO = tile_m * hdimv * 2 * dO_stage
- sP = tile_m * tile_n * 2 (0 if mma_dkv_is_rs)
- sdS = tile_m * tile_n * 2
- sdQaccum = tile_m * hdim * 4 (f32)
## SMEM Traffic
Per-iteration smem bandwidth consumed. Each GMMA instruction reads:
- **A operand**: 64 * K_red * 2 bytes (0 if register-source)
- **B operand**: (N_eff / wg_n) * K_red * 2 bytes
Total instructions = (M_eff / 64) * wg_n. Each instruction independently reads A and B from smem.
Additional traffic: R2S stores for P, dS (bf16), dQ smem store + TMA load (f32).
**Traffic per block** (traffic / (tile_m * tile_n)) normalizes across tile sizes for comparison. Lower is better.
## Example Configs
### hdim=128 (Forward)
Best: tile_m=128, tile_n=192, RS, 2 WG. 224K smem, 9.3 tr/blk.
### hdim=128 (Backward, non-causal)
C++ FA3 config: tile_m=80, tile_n=128, SdP_swap=T, dKV_swap=F, dQ_swap=T, aSdP=1, adKV=2. mma_dkv_is_rs=True. 204K smem, 208 regs, 39.6 tr/blk.
### hdim=192 (Backward)
3 WG, tile_m=64, tile_n=96, SdP_swap=F, dKV_swap=T, adKV=1 or 3. 216K smem, 128 regs. This is the only feasible tile_n > 64 for hdim=192 due to register pressure.
### hdim=192, hdimv=128 (DeepSeek shape)
With 3 WG: need AtomLayoutNdKV=3 (since hdimv=128 not divisible by 3). tile_n=96, 212K smem.
With 2 WG: tile_n=112 feasible at 210K smem, or tile_n=64 at 168K smem.
================================================
FILE: AI/SM90_R2P_MASKING_SASS.md
================================================
# SM90 FWD R2P Masking — SASS Investigation
## SASS Instruction Counts (hdim=128, seqlen=113, tile_n=128)
With tile_n=128, SM90 has 32 accumulator elements per row (1 chunk of 32).
### Non-causal (seqlen-only masking)
| Metric | Old (no R2P) | New (R2P) | Delta |
|--------|-------------|-----------|-------|
| **Total instructions** | 3104 | 3072 | **-32 (-1%)** |
| R2P | 0 | 4 | +4 |
| FSEL | 70 | 70 | 0 |
| ISETP | 55 | 22 | **-33** |
| SHF | 69 | 73 | +4 |
| LOP3 | 51 | 56 | +5 |
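R2P replaces 33 ISETP (integer set-predicate) instructions with 4 R2P + a few LOP3/SHF. Net savings: 32 instructions. Each of the 4 R2P instructions converts the low 7 bits of one byte of a 32-bit bitmask into 7 predicates, covering 28 of the 32 elements (4 × 7 bits = 28); the remaining 4 elements (bit 7 of each byte) are tested separately, as described in "Handling the leftover bits" below.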
### Causal
| Metric | Old (no R2P) | New (R2P) | Delta |
|--------|-------------|-----------|-------|
| **Total instructions** | 5008 | 4857 | **-151 (-3%)** |
| R2P | 0 | 24 | +24 |
| FSEL | 200 | 200 | 0 |
| ISETP | 225 | 22 | **-203** |
| SHF | 104 | 105 | +1 |
| LOP3 | 81 | 105 | +24 |
Much larger savings. The causal kernel applies masking per-row (each row has a different col_limit), so it has many more masking operations. 24 R2P instructions replace 203 ISETP instructions, saving 151 total.
### Local (sliding window, wl=64 wr=0)
| Metric | Old (no R2P) | New (R2P) | Delta |
|--------|-------------|-----------|-------|
| **Total instructions** | 7296 | 6217 | **-1079 (-15%)** |
| R2P | 0 | 32 | +32 |
| FSEL | 522 | 266 | **-256** |
| ISETP | 554 | 22 | **-532** |
| SHF | 115 | 73 | -42 |
| LOP3 | 96 | 56 | -40 |
Dramatic savings. Local masking has two bounds (left + right) per row, doubling the masking work. R2P eliminates 532 ISETP and 256 FSEL instructions, saving 1079 total (15% of kernel).
## How R2P Works in SASS
The compiler generates this pattern:
```
SHF.R.U32.HI R9, RZ, R9, R16 ; shift to create bitmask
R2P PR, R9, 0x7f ; byte 0 → predicates P0-P6
FSEL R15, R36, -INF, P6 ; apply P6: keep or mask to -inf
R2P PR, R9.B1, 0x7f ; byte 1 → predicates P0-P6
FSEL R52, R52, -INF, P6 ; apply P6
R2P PR, R9.B2, 0x7f ; byte 2
...
R2P PR, R9.B3, 0x7f ; byte 3
```
Each `R2P` converts 7 bits of a register byte into 7 predicate registers simultaneously (1 instruction instead of 7 `ISETP`). The subsequent `FSEL` instructions use these predicates for conditional masking.
### Handling the leftover bits (32 is not divisible by 7)
The `0x7f` immediate tells R2P to map bits 0-6 of each byte to P0-P6, but bit 7 (the MSB of each byte) is not covered. For 32 elements across 4 bytes, that's 4 leftover elements (bits 7, 15, 23, 31). The compiler handles these with separate `LOP3.LUT` or `ISETP` instructions:
```
R2P PR, R12, 0x7f ; bits 0-6 → P0-P6 (7 elements)
14× FSEL using P0-P6 ; apply to 7 cols × 2 rows
LOP3.LUT P0, RZ, R12, 0x80, ... ; test bit 7 (1 element)
2× FSEL using P0
R2P PR, R12.B1, 0x7f ; bits 8-14 → P0-P6 (7 elements)
14× FSEL using P0-P6
LOP3.LUT P1, RZ, R12, 0x8000, ..; test bit 15 (1 element)
2× FSEL using P1
R2P PR, R12.B2, 0x7f ; bits 16-22 → P0-P6 (7 elements)
14× FSEL using P0-P6
LOP3.LUT P0, RZ, R12, 0x800000,..; test bit 23 (1 element)
2× FSEL using P0
R2P PR, R12.B3, 0x7f ; bits 24-30 → P0-P6 (7 elements)
14× FSEL using P0-P6
ISETP.GT P0, R12, -1 ; test bit 31 (sign bit) (1 element)
2× FSEL using P0
```
Total: 4×7 = 28 elements via R2P + 4 elements via LOP3/ISETP = 32. Each R2P replaces 7 ISETP with 1 instruction, so net savings is `(7-1) × 4 = 24` predicate-generation instructions per mask application. Additionally, ptxas can overlap R2P with FSEL since they write to separate predicate registers.
## Performance Impact
| Case | Old (ms) | New (ms) | Speedup |
|------|----------|----------|---------|
| Causal hdim=64 s=8192 | 2.463 | 2.473 | ~0% |
| Causal hdim=128 s=8192 | 1.937 | 1.944 | ~0% |
| Local hdim=64 s=8192 | 0.394 | 0.346 | **+14%** |
| Local hdim=128 s=8192 | 0.237 | 0.222 | **+7%** |
| Non-causal hdim=128 s=4096 | 1.742 | 1.728 | ~1% |
Causal sees no perf gain despite fewer instructions because masking is a tiny fraction of total work (dominated by WGMMA). Local sees significant gains because the sliding window has many partially-masked blocks where masking overhead matters more.
================================================
FILE: AI/VARLEN_PREPROCESS_TILE_BUG.md
================================================
# Varlen Preprocess Tile Mismatch Bug
## Summary
`SeqlenInfo.create` in `flash_bwd_preprocess.py` defaulted `tile=128`, but the backward kernel uses `tile_m=m_block_size` (e.g. 64 for causal SM90). This caused the preprocess to zero dq_accum and write lse_log2/dpsum at wrong padded offsets for all batches after batch 0.
## How padded_offset works
For varlen, buffers like dq_accum are laid out with tile-aligned gaps between sequences:
```
padded_offset_q = ((offset_q + batch_idx * tile_m) // tile_m) * tile_m
```
The gap size depends on `tile_m`. With `tile_m=64` vs `tile_m=128`, batch 1 at `offset_q=128` gets:
- tile=64: padded_offset = ((128 + 64) // 64) * 64 = **192**
- tile=128: padded_offset = ((128 + 128) // 128) * 128 = **256**
The preprocess was zeroing at 256, the backward was writing at 192.
## Symptoms
- Tests pass in isolation (torch.empty gets clean memory)
- Tests fail when run in sequence (CUDA memory caching reuses NaN-polluted memory)
- dq_accum valid positions contain NaN after backward kernel
- `torch.zeros` for dq_accum masks the bug (zeroes everywhere, including the "right" offsets)
- compute-sanitizer shows 0 errors (addresses are valid, just wrong offsets within the buffer)
## Fix
```python
# flash_bwd_preprocess.py line 216
# Before:
seqlen = SeqlenInfo.create(batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ)
# After:
seqlen = SeqlenInfo.create(batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m)
```
## Lesson
Any code computing `padded_offset` for varlen buffers must use the same tile size as the kernel that allocated and accesses those buffers. The `SeqlenInfo.create` default `tile=128` is a trap when `m_block_size != 128`.
================================================
FILE: AI/racecheck_repro_1d_bulk.py
================================================
"""Minimal reproducer: cp.async.bulk (raw address) triggers racecheck hazard.
Warp 0 loads via cp.async.bulk, warp 1 reads from smem after mbarrier wait.
Pipeline is correctly synchronized but racecheck reports 1 error.
python AI/racecheck_repro_1d_bulk.py # correctness
CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_bulk.py # 1 error
"""
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
from cutlass import Float32, Int32
import cutlass.pipeline
from cutlass.pipeline.sm90 import PipelineTmaAsync, make_pipeline_state
import cuda.bindings.driver as cuda
import torch
N_BLKS, TILE = 4, 128
N_STG = 2
@cute.kernel
def kernel(g_src: cute.Tensor, g_dst: cute.Tensor):
    """Stream ``g_src`` into ``g_dst`` through an ``N_STG``-stage smem pipeline.

    Warp 0 acts as producer and fills one smem stage per iteration with a
    raw-address bulk copy (``cpasync.CopyBulkG2SOp``). Warp 1 acts as consumer:
    it waits on the pipeline, copies the stage out to ``g_dst`` with ordinary
    per-thread smem reads, and then releases the stage.

    Args:
        g_src: source tensor, tiled into ``TILE``-element blocks.
        g_dst: destination tensor, tiled the same way.
    """
    allocator = cutlass.utils.SmemAllocator()
    # One TILE-element Float32 column per pipeline stage.
    stage_buf = allocator.allocate_tensor(
        Float32,
        cute.make_layout((TILE, N_STG)),
        byte_alignment=128,
    )
    # Barrier storage handed to the pipeline: 2 * N_STG 64-bit slots.
    mbar_buf = allocator.allocate_tensor(
        cutlass.Int64,
        cute.make_layout(2 * N_STG),
        byte_alignment=8,
    )

    tidx, _, _ = cute.arch.thread_idx()
    warp = tidx // 32
    lane = tidx % 32

    pipe = PipelineTmaAsync.create(
        barrier_storage=mbar_buf.iterator,
        num_stages=N_STG,
        producer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        consumer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        tx_count=TILE * 4,
        defer_sync=False,
    )

    # Per-block views of the global tensors; the second mode selects the block.
    gmem_in = cute.local_tile(g_src, (TILE,), (None,))
    gmem_out = cute.local_tile(g_dst, (TILE,), (None,))

    # Producer warp: one bulk copy per block into the current stage.
    if warp == 0:
        prod_state = make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, N_STG)
        for tile_idx in cutlass.range(N_BLKS, unroll=1):
            pipe.producer_acquire(prod_state)
            bulk_atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)
            # Only the elected thread issues the copy.
            with cute.arch.elect_one():
                cute.copy(
                    bulk_atom,
                    gmem_in[None, tile_idx],
                    stage_buf[None, prod_state.index],
                    mbar_ptr=pipe.producer_get_barrier(prod_state),
                )
            prod_state.advance()
        pipe.producer_tail(prod_state)

    # Consumer warp: wait for a stage, drain it to global memory, release it.
    if warp == 1:
        cons_state = make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, N_STG)
        for tile_idx in cutlass.range(N_BLKS, unroll=1):
            pipe.consumer_wait(cons_state)
            # Each lane handles TILE // 32 elements, strided by the warp width.
            for chunk in cutlass.range_constexpr(TILE // 32):
                gmem_out[lane + chunk * 32, tile_idx] = stage_buf[lane + chunk * 32, cons_state.index]
            cute.arch.fence_view_async_shared()
            # Need sync_warp because only 1 thread will signal in consumer_release.
            cute.arch.sync_warp()
            pipe.consumer_release(cons_state)
            cons_state.advance()
@cute.jit
def go(g_src, g_dst, stream):
    """Launch ``kernel`` on ``stream`` as one block of 64 threads with 4096 bytes of smem."""
    launcher = kernel(g_src, g_dst)
    launcher.launch(
        grid=[1, 1, 1],
        block=[64, 1, 1],
        smem=4096,
        stream=stream,
    )
if __name__ == "__main__":
    # Host-side correctness check: the kernel must reproduce the input exactly.
    src = torch.arange(TILE * N_BLKS, device="cuda", dtype=torch.float32)
    dst = torch.zeros_like(src)

    # Wrap the torch tensors and the current torch stream for the CuTe launcher.
    src_view = from_dlpack(src, assumed_align=16)
    dst_view = from_dlpack(dst, assumed_align=16)
    cu_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    go(src_view, dst_view, cu_stream)

    # Wait for the kernel before comparing results on the host side.
    torch.cuda.synchronize()
    assert torch.equal(src, dst), f"FAIL: max diff={torch.abs(src - dst).max().item()}"
    print("PASS")
================================================
FILE: AI/racecheck_repro_1d_tensor.py
================================================
"""Minimal reproducer: cp.async.bulk.tensor.1d (descriptor TMA) passes racecheck.
Same pipeline as racecheck_repro_1d_bulk.py but uses make_tiled_tma_atom to
create a TMA descriptor, which generates cp.async.bulk.tensor.1d PTX.
python AI/racecheck_repro_1d_tensor.py # correctness
CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_tensor.py # 0 hazards
"""
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
from cutlass import Float32, Int32
import cutlass.pipeline
from cutlass.pipeline.sm90 import PipelineTmaAsync, make_pipeline_state
import cuda.bindings.driver as cuda
import torch
N_BLKS, TILE = 4, 128
N_STG = 2
@cute.kernel
def kernel(g_dst: cute.Tensor, tma_atom: cute.CopyAtom, tma_tensor: cute.Tensor):
    """Stream the TMA source tensor into ``g_dst`` through an ``N_STG``-stage smem pipeline.

    Warp 0 acts as producer and fills one smem stage per iteration using the
    descriptor-based ``tma_atom``. Warp 1 acts as consumer: it waits on the
    pipeline, copies the stage out to ``g_dst`` with ordinary per-thread smem
    reads, and then releases the stage.

    Args:
        g_dst: destination tensor, tiled into ``TILE``-element blocks.
        tma_atom: copy atom built on the host by ``make_tiled_tma_atom``.
        tma_tensor: tensor returned alongside ``tma_atom``; used as the copy source.
    """
    allocator = cutlass.utils.SmemAllocator()
    # One TILE-element Float32 column per pipeline stage.
    stage_buf = allocator.allocate_tensor(
        Float32,
        cute.make_layout((TILE, N_STG)),
        byte_alignment=128,
    )
    # Barrier storage handed to the pipeline: 2 * N_STG 64-bit slots.
    mbar_buf = allocator.allocate_tensor(
        cutlass.Int64,
        cute.make_layout(2 * N_STG),
        byte_alignment=8,
    )

    tidx, _, _ = cute.arch.thread_idx()
    warp = tidx // 32
    lane = tidx % 32

    pipe = PipelineTmaAsync.create(
        barrier_storage=mbar_buf.iterator,
        num_stages=N_STG,
        producer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        consumer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        tx_count=TILE * 4,
        defer_sync=False,
    )

    # Partition the smem stages and the TILE-sized source blocks for the TMA atom.
    smem_part, gmem_part = cpasync.tma_partition(
        tma_atom,
        Int32(0),
        cute.make_layout(1),
        cute.group_modes(stage_buf, 0, 1),
        cute.group_modes(cute.local_tile(tma_tensor, (TILE,), (None,)), 0, 1),
    )
    gmem_out = cute.local_tile(g_dst, (TILE,), (None,))

    # A single elected thread of warp 0 prefetches the TMA descriptor.
    if warp == 0:
        with cute.arch.elect_one():
            cpasync.prefetch_descriptor(tma_atom)

    # Producer warp: one descriptor-based copy per block into the current stage.
    if warp == 0:
        prod_state = make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, N_STG)
        for tile_idx in cutlass.range(N_BLKS, unroll=1):
            pipe.producer_acquire(prod_state)
            cute.copy(
                tma_atom,
                gmem_part[None, tile_idx],
                smem_part[None, prod_state.index],
                tma_bar_ptr=pipe.producer_get_barrier(prod_state),
            )
            prod_state.advance()
        pipe.producer_tail(prod_state)

    # Consumer warp: wait for a stage, drain it to global memory, release it.
    if warp == 1:
        cons_state = make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, N_STG)
        for tile_idx in cutlass.range(N_BLKS, unroll=1):
            pipe.consumer_wait(cons_state)
            # Each lane handles TILE // 32 elements, strided by the warp width.
            for chunk in cutlass.range_constexpr(TILE // 32):
                gmem_out[lane + chunk * 32, tile_idx] = stage_buf[lane + chunk * 32, cons_state.index]
            cute.arch.fence_view_async_shared()
            # Need sync_warp because only 1 thread will signal in consumer_release.
            cute.arch.sync_warp()
            pipe.consumer_release(cons_state)
            cons_state.advance()
@cute.jit
def go(g_src, g_dst, stream):
    """Build the TMA atom for ``g_src`` and launch ``kernel`` on ``stream``.

    The kernel runs as one block of 64 threads with 4096 bytes of smem.
    """
    # Descriptor-based global-to-shared copy over TILE-element tiles of g_src.
    g2s_atom, tma_src = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileG2SOp(),
        g_src,
        cute.make_layout(TILE),
        (TILE,),
    )
    launcher = kernel(g_dst, g2s_atom, tma_src)
    launcher.launch(
        grid=[1, 1, 1],
        block=[64, 1, 1],
        smem=4096,
        stream=stream,
    )
if __name__ == "__main__":
    # Host-side correctness check: the kernel must reproduce the input exactly.
    src = torch.arange(TILE * N_BLKS, device="cuda", dtype=torch.float32)
    dst = torch.zeros_like(src)

    # Wrap the torch tensors and the current torch stream for the CuTe launcher.
    src_view = from_dlpack(src, assumed_align=16)
    dst_view = from_dlpack(dst, assumed_align=16)
    cu_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    go(src_view, dst_view, cu_stream)

    # Wait for the kernel before comparing results on the host side.
    torch.cuda.synchronize()
    assert torch.equal(src, dst), f"FAIL: max diff={torch.abs(src - dst).max().item()}"
    print("PASS")
================================================
FILE: AUTHORS
================================================
Tri Dao, trid@cs.stanford.edu
================================================
FILE: CLAUDE.md
================================================
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
FlashAttention-4 (FA4) — fast, memory-efficient exact attention kernels written in Python using CuTeDSL (NVIDIA CUTLASS DSL). Kernels are compiled to PTX/CUBIN at runtime. Targets Hopper (SM90) and Blackwell (SM100/SM110) GPUs. Package name: `flash-attn-4`.
The repository also contains older generations (FA2 in top-level `csrc/`, FA3 in `hopper/`) but active development is on FA4 in `flash_attn/cute/`.
## Build & Install
```bash
pip install flash-attn-4
# or dev install:
pip install -e "flash_attn/cute[dev]"
```
Dependencies: `nvidia-cutlass-dsl>=4.4.1`, `torch`, `einops`, `apache-tvm-ffi`, `quack-kernels>=0.2.10`.
## Running Tests
```bash
pytest tests/cute/test_flash_attn.py
pytest tests/cute/test_flash_attn.py -k "test_flash_attn_output" -x # single test
pytest tests/cute/test_flash_attn_varlen.py
pytest tests/cute/test_mask_mod.py
pytest tests/cute/test_score_mod.py
pytest tests/cute/test_block_sparsity.py
```
### Fast two-pass testing
Compilation dominates test time. The fast workflow separates compilation (parallel, no GPU needed) from execution (uses cached binaries):
```bash
# Pass 1: compile all kernels in parallel using FakeTensorMode (no GPU memory allocation)
FLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 64 -x tests/cute/test_flash_attn.py
# Pass 2: run tests using cached compiled kernels
FLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -x tests/cute/test_flash_attn.py
```
- `FLASH_ATTENTION_FAKE_TENSOR=1` — uses PyTorch FakeTensorMode to compile kernels without allocating GPU memory or running them.
- `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` — enables persistent disk cache at `/tmp/${USER}/flash_attention_cute_dsl_cache/`.
- `-n 64` — number of pytest-xdist parallel workers (only useful in the compilation pass; adjust to the number of available CPU cores).
Tests are parametrized over dtype (fp16/bf16), head dimension (64, 96, 128), sequence length, causal/non-causal, and MHA/GQA/MQA.
If you get OOM errors running tests or benchmarks, use `nvidia-smi` to find a free GPU and select it with `CUDA_VISIBLE_DEVICES=<id>`.
## Linting
Pre-commit uses ruff on `flash_attn/cute/` files. Large kernel files (`flash_bwd.py`, `flash_fwd.py`, `flash_fwd_sm100.py`, `interface.py`) are excluded from auto-formatting.
```bash
ruff check flash_attn/cute/ --fix
ruff format flash_attn/cute/
```
## Code Architecture
### Public API (`flash_attn/cute/interface.py`)
Two entry points exported from `flash_attn/cute/__init__.py`:
- `flash_attn_func(q, k, v, ...)` — standard attention
- `flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)` — variable-length
Key parameters: `causal`, `window_size_left/right`, `softmax_scale`, `softcap`, `score_mod`, `mask_mod`, `block_sparse_tensors`, `num_splits`, `pack_gqa`, `m_block_size`, `n_block_size`, `num_threads`.
Tensor layout: `(batch, seqlen, num_heads, head_dim)`, last dim contiguous, 16-byte aligned.
### Forward Kernels
- `flash_fwd.py` — `FlashAttentionForwardSm90`: Hopper forward. No SplitKV or paged KV.
- `flash_fwd_sm100.py` — `FlashAttentionForwardSm100`: Blackwell forward. Full features including SplitKV, paged KV cache, persistent kernels, 2CTA instructions.
- `flash_fwd_combine.py` — `FlashAttentionForwardCombine`: merges SplitKV partial results.
### Backward Kernels
- `flash_bwd.py` — `FlashAttentionBackwardSm80`: Ampere backward (base).
- `flash_bwd_sm90.py` — `FlashAttentionBackwardSm90`: Hopper backward.
- `flash_bwd_sm100.py` — `FlashAttentionBackwardSm100`: Blackwell backward with 2CTA and block sparse support.
- `flash_bwd_preprocess.py` / `flash_bwd_postprocess.py` — auxiliary backward kernels.
### Core Abstractions
- `softmax.py` — Online softmax with row_max/row_sum tracking, score modifier support.
- `mask.py` — `AttentionMask`: causal, local/sliding window, block sparse, mask_mod application.
- `block_info.py` — `BlockInfo`: tile dimensions, n/m block range computation for causal/local masking.
- `seqlen_info.py` — `SeqlenInfoQK`: sequence length and offset tracking for varlen.
- `pipeline.py` — `PipelineStateSimple`: circular buffer index/phase management for pipelined loads.
- `tile_scheduler.py` — Tile scheduling strategies (single tile, varlen-aware, persistent).
- `copy_utils.py` — Type-converting copies, shared-to-register loads, TMA copy atoms.
- `named_barrier.py` — Named barrier enums for warp synchronization.
### Architecture-Specific Helpers
- `hopper_helpers.py` — SM90 warp-group GEMM, shared memory layout creation, fence/commit/wait.
- `blackwell_helpers.py` — SM100 UMMA-based GEMM, PTX-optimized paths, 2CTA support.
- `mma_sm100_desc.py` — Hardware MMA descriptor enums (formats, saturation, scaling).
### Other Components
- `pack_gqa.py` — Packs multiple Q heads per KV head for efficient GQA.
- `paged_kv.py` — `PagedKVManager`: paged KV cache with TMA support.
- `fast_math.py` — exp2 polynomial coefficients, softcap score_mod creation.
- `utils.py` — Hash functions for compile cache keys, warp reductions, predicates.
- `cache_utils.py` — JIT compilation cache management.
- `cute_dsl_utils.py` — Patched `cute.compile` that optionally dumps SASS.
### Compilation & Caching
Kernels are JIT-compiled. Cache key includes dtype, head_dim, causal, mask/score_mod hashes, architecture, block sizes. Caching levels: in-memory LRU + optional disk cache via `get_jit_cache()`.
Env vars: `CUTE_CUBIN_PATH` (dump CUBIN/SASS), `CUTE_DSL_KEEP_PTX=1` (inspect PTX), `CUTE_DSL_PTXAS_PATH` (custom ptxas).
## Key Patterns
- Compile-time constants use `cutlass.Constexpr[type]` for kernel specialization.
- Score/mask modifiers are user-defined `@cute.jit` callables injected into the kernel at compile time.
- Forward execution: load Q tile → loop over K/V blocks (pipelined) → online softmax accumulation → store O and LSE.
- 2CTA instructions (SM100, hdim=128): both CTAs in a cluster coordinate via shared mbarriers; tx_count must be multiplied by `cta_group_size`.
## Debugging GPU Kernels
See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`.
Key tools:
- `cute.printf` with thread guards (`tidx % 32 == 0`, `elect_one()`) for targeted output
- `compute-sanitizer --tool=racecheck` (beware false positives with raw TMA)
- `CUTE_DSL_KEEP_PTX=1` and `CUTE_DSL_LINEINFO=1` for PTX inspection and sanitizer source mapping
================================================
FILE: LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: MANIFEST.in
================================================
recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
recursive-include csrc *.py
recursive-include flash_attn *.cu
recursive-include flash_attn *.h
recursive-include flash_attn *.cuh
recursive-include flash_attn *.cpp
recursive-include flash_attn *.hpp
================================================
FILE: Makefile
================================================
# Packaging helpers for building and uploading the source distribution.
#
# None of these targets creates a file named after itself, so they are declared
# phony. Without this, a stray file or directory called e.g. `clean_dist` would
# make `make` treat the target as up to date and silently skip its recipe.
.PHONY: clean_dist create_dist upload_package

# Remove previously built artifacts from dist/.
clean_dist:
	rm -rf dist/*

# Build a fresh source distribution (runs clean_dist first).
create_dist: clean_dist
	python setup.py sdist

# Upload the contents of dist/ with twine (runs create_dist first).
upload_package: create_dist
	twine upload dist/*
================================================
FILE: README.md
================================================
# FlashAttention
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the
following papers.
**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.

**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao
Paper: https://tridao.me/publications/flash2/flash2.pdf

## Usage
We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.
## FlashAttention-3 beta release
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).
Blogpost: https://tridao.me/blog/2024/flash3/
Paper: https://tridao.me/publications/flash3/flash3.pdf

This is a beta release for testing / benchmarking before we integrate that with
the rest of the repo.
Currently released:
- FP16 / BF16 forward and backward, FP8 forward
Requirements: H100 / H800 GPU, CUDA >= 12.3.
We highly recommend CUDA 12.8 for best performance.
To install:
```sh
cd hopper
python setup.py install
```
To run the test:
```sh
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```
Once the package is installed, you can import it as follows:
```python
import flash_attn_interface
flash_attn_interface.flash_attn_func()
```
## FlashAttention-4 (CuTeDSL)
FlashAttention-4 is written in CuTeDSL and optimized for Hopper and Blackwell GPUs (e.g. H100, B200).
To install:
```sh
pip install flash-attn-4
```
Once installed, you can use it as follows:
```python
from flash_attn.cute import flash_attn_func
out = flash_attn_func(q, k, v, causal=True)
```
## Installation and features
**Requirements:**
- CUDA toolkit or ROCm toolkit
- PyTorch 2.2 and above.
- `packaging` Python package (`pip install packaging`)
- `psutil` Python package (`pip install psutil`)
- `ninja` Python package (`pip install ninja`) *
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.
\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit.
**To install:**
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```sh
python setup.py install
```
If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
run too many parallel compilation jobs that could exhaust the amount of RAM. To
limit the number of parallel compilation jobs, you can set the environment
variable `MAX_JOBS`:
```sh
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```
**Interface:** `src/flash_attention_interface.py`
### NVIDIA CUDA Support
**Requirements:**
- CUDA 12.0 and above.
We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
container from Nvidia, which has all the required tools to install FlashAttention.
FlashAttention-2 with CUDA currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
### AMD ROCm Support
ROCm version has two backends. There is [composable_kernel](https://github.com/ROCm/composable_kernel) (ck) which is the default backend and a [Triton](https://github.com/triton-lang/triton) backend. They provide an implementation of FlashAttention-2.
**Requirements:**
- ROCm 6.0 and above.
We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.
#### Composable Kernel Backend
FlashAttention-2 ROCm CK backend currently supports:
1. MI200x, MI250x, MI300x, and MI355x GPUs.
2. Datatype fp16 and bf16
3. Both forward's and backward's head dimensions up to 256.
#### Triton Backend
The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.
The Triton backend kernels are provided by the [aiter](https://github.com/ROCm/aiter) package, included as a git submodule at `third_party/aiter` and automatically installed during setup.
To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Flash Attention:
```sh
cd flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
```
To use a specific aiter commit (e.g., for testing or development):
```sh
cd flash-attention
cd third_party/aiter && git fetch origin && git checkout <commit-sha> && cd ../..
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
```
To run the tests (note: full suite takes hours):
```sh
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
```
The Triton backend uses a default kernel configuration optimized for determinism and reasonable performance across workloads. For peak throughput, enable `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` to search for optimal settings, which incurs a one-time warmup cost.
Alternatively, if _not_ autotuning, `FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON` may be used to set a single triton config overriding the hardcoded defaults for `attn_fwd`. E.g.
```sh
FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON='{"BLOCK_M":128,"BLOCK_N":64,"waves_per_eu":1,"PRE_LOAD_V":false,"num_stages":1,"num_warps":8}'
```
For a quick start with Docker:
```dockerfile
FROM rocm/pytorch:latest
WORKDIR /workspace
# build flash attention with triton backend
RUN git clone https://github.com/Dao-AILab/flash-attention &&\
cd flash-attention &&\
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
# set working dir
WORKDIR /workspace/flash-attention
# set env variable to use triton backend
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
```
Build and run:
```sh
docker build -t flash-attn-triton .
docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton
```
## How to use FlashAttention
The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
```python
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True,
alibi_slopes=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).
### Using with 🤗 Kernels
If your hardware is among those listed above, you can also use the [`kernels` library](https://github.com/huggingface/kernels)
to start using Flash Attention 2 and 3 right away.
```py
# pip install kernels
from kernels import get_kernel
# FA2
fa_module = get_kernel("kernels-community/flash-attn2", version=1)
flash_attn_func = fa_module.flash_attn_func
# FA3
fa3_module = get_kernel("kernels-community/flash-attn3", version=1)
flash_attn_func = fa3_module.flash_attn_func
```
## Changelog
### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2
These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`
If the inputs have the same sequence lengths in the same batch, it is simpler
and faster to use these functions:
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
### 2.1: Change behavior of causal flag
If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 =
masked out) is:
v2.0:
1 0 0 0 0
1 1 0 0 0
v2.1:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
v2.0:
1 0
1 1
1 1
1 1
1 1
v2.1:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
### 2.2: Optimize for inference
Optimize for inference (iterative decoding) when query has very small sequence
length (e.g., query sequence length = 1). The bottleneck here is to load KV
cache as fast as possible, and we split the loading across different thread
blocks, with a separate kernel to combine results.
See the function `flash_attn_with_kvcache` with more features for inference
(perform rotary embedding, updating KV cache inplace).
Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.
### 2.3: Local (i.e., sliding window) attention
Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
Implement deterministic backward pass. Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.
### 2.5: Paged KV cache.
Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
Thanks to @beginlner for this contribution.
### 2.6: Softcapping.
Support attention with softcapping, as used in Gemma-2 and Grok models.
Thanks to @Narsil and @lucidrains for this contribution.
### 2.7: Compatibility with torch compile
Thanks to @ani300 for this contribution.
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
We currently have benchmarks for these GPUs:
* [A100](#a100)
* [H100](#h100)
<!-- * [RTX 3090](#rtx-3090) -->
<!-- * [T4](#t4) -->
### A100
We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.
#### Speedup

#### Memory

We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking).
Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.
### H100

## Full model code and training script
We have released the full GPT model
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 225
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
any activation checkpointing).
We also include a training
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.
## Triton implementation of FlashAttention
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.
We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
## Tests
We test that FlashAttention produces the same output and gradient as a reference
implementation, up to some numerical tolerance. In particular, we check that the
maximum numerical error of FlashAttention is at most twice the numerical error
of a baseline implementation in Pytorch (for different head dimensions, input
dtype, sequence length, causal / non-causal).
To run the tests:
```sh
pytest -q -s tests/test_flash_attn.py
```
## When you encounter issues
This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.
If you encounter bugs, please open a GitHub Issue!
## Tests (ROCm CK backend)
To run the tests:
```sh
pytest tests/test_flash_attn_ck.py
```
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2022}
}
@inproceedings{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
```
================================================
FILE: benchmarks/bench_sm90.py
================================================
#!/usr/bin/env python
"""Unified SM90 benchmark for forward and backward passes.
Usage:
# Default: bench fwd+bwd for hdim 64,96,128 at seqlen 8192
python benchmarks/bench_sm90.py
# Forward only, specific hdims
python benchmarks/bench_sm90.py --direction fwd --hdim 64,96
# Backward only
python benchmarks/bench_sm90.py --direction bwd --hdim 128
# Custom seqlens and batch size
python benchmarks/bench_sm90.py --seqlen 1024,2048,4096,8192 --batch 0
# Sweep tile sizes for fwd
python benchmarks/bench_sm90.py --sweep-tiles --hdim 96
# Sweep tile sizes for fwd (all hdims including 192, 256)
python benchmarks/bench_sm90.py --sweep-tiles --hdim 64,96,128,192,256
# Sweep RS/overlap variants
python benchmarks/bench_sm90.py --sweep-rs-overlap --hdim 64,96
# Compare old vs new configs
python benchmarks/bench_sm90.py --compare-configs
# Sweep backward optimizations (V_in_regs, mma_dkv_is_rs, pipeline sharing)
python benchmarks/bench_sm90.py --sweep-bwd-opts --hdim 64,128
# Causal only, more reps and warmup
python benchmarks/bench_sm90.py --causal-only --rep 50 --warmup 10
"""
import argparse
import time
import torch
import torch.nn.functional as F
from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd
# ── Helpers ────────────────────────────────────────────────────────────────
def parse_int_k(s):
    """Convert a string such as '512' or '8k' into an int.

    Surrounding whitespace is ignored and the suffix is case-insensitive.
    A trailing 'k' multiplies the numeric part by 1024, so '8k' -> 8192.
    """
    text = s.strip().lower()
    has_k_suffix = text.endswith("k")
    # Drop the suffix character before handing the digits to int().
    digits = text[:-1] if has_k_suffix else text
    multiplier = 1024 if has_k_suffix else 1
    return int(digits) * multiplier
def csv_ints(s):
    """Split a comma-separated string and parse each field with parse_int_k.

    Example: '512,1k,2k' -> [512, 1024, 2048].
    """
    values = []
    for field in s.split(","):
        values.append(parse_int_k(field))
    return values
def parse_headdims(s):
    """Turn a comma-separated head-dimension spec into (hdim, hdim_v) pairs.

    Each entry is either a single integer (used for both hdim and hdim_v) or
    two integers joined by '-' (hdim first, hdim_v second).

    Examples:
        '128'                -> [(128, 128)]
        '192-128'            -> [(192, 128)]
        '64,128,192'         -> [(64, 64), (128, 128), (192, 192)]
        '64,128,192-128,192' -> [(64, 64), (128, 128), (192, 128), (192, 192)]
    """
    def _parse_entry(entry):
        if "-" not in entry:
            dim = int(entry)
            return (dim, dim)
        # Only the first two '-'-separated fields are used.
        fields = entry.split("-")
        return (int(fields[0]), int(fields[1]))

    return [_parse_entry(entry) for entry in s.split(",")]
def nheads_for_hdim(h):
    """Pick the number of heads to benchmark with for head dimension `h`.

    Returns 32 for h <= 64, 16 for 64 < h <= 192, and 8 for anything larger.
    """
    if h <= 64:
        return 32
    if h <= 192:
        return 16
    return 8
def fwd_flops(batch, nheads, seqlen, hdim, hdim_v=None, causal=False):
    """Return the FLOP count used to convert forward timings into TFLOPS.

    Computes batch * nheads * 2 * seqlen * keys * (hdim + hdim_v), where
    `keys` is seqlen for non-causal attention and seqlen / 2 for causal
    attention. `hdim_v` defaults to `hdim` when not given.
    """
    v_dim = hdim if hdim_v is None else hdim_v
    if causal:
        # Causal runs are counted with half the sequence length per query.
        keys_per_query = seqlen / 2
    else:
        keys_per_query = seqlen
    return batch * nheads * 2 * seqlen * keys_per_query * (hdim + v_dim)
def bwd_flops(batch, nheads, seqlen, hdim, causal=False, hdim_v=None):
    """Return the backward FLOP count, taken as 2.5x the forward count.

    Note that `causal` comes before `hdim_v` here, unlike in fwd_flops.
    """
    forward_cost = fwd_flops(batch, nheads, seqlen, hdim, hdim_v=hdim_v, causal=causal)
    return 2.5 * forward_cost
def get_causals(args):
    """Return the list of causal settings to benchmark.

    `args.causal_only` takes precedence and yields [True];
    `args.non_causal_only` yields [False]; otherwise both [False, True].
    """
    if args.causal_only:
        selected = [True]
    elif args.non_causal_only:
        selected = [False]
    else:
        selected = [False, True]
    return selected
def auto_batch(seqlen, batch_arg, total_tokens=32768):
    """Choose the batch size for a given sequence length.

    A positive `batch_arg` is returned unchanged. Otherwise the batch is
    derived so that batch * seqlen is about `total_tokens`, with a floor of 1.
    """
    if batch_arg > 0:
        return batch_arg
    return max(1, total_tokens // seqlen)
# ── Core bench functions ──────────────────────────────────────────────────
def bench_fwd(batch, seqlen, nheads, hdim, causal, tile_m=None, tile_n=None,
              mma_pv_is_rs=None, intra_wg_overlap=None, check_correctness=True,
              warmup=5, rep=30, hdim_v=None):
    """Time `_flash_attn_fwd` for one problem shape.

    Random bf16 Q and K of shape (batch, seqlen, nheads, hdim) and V of shape
    (batch, seqlen, nheads, hdim_v) are created on the CUDA device; `hdim_v`
    defaults to `hdim`. `tile_mn` is forwarded only when both `tile_m` and
    `tile_n` are given; `mma_pv_is_rs` and `intra_wg_overlap` are forwarded
    only when not None.

    Returns:
        (ms, tflops, max_diff) on success. `ms` is the mean time per call over
        `rep` timed iterations and `max_diff` is the largest absolute
        difference against PyTorch's scaled_dot_product_attention (None when
        `check_correctness` is False).
        (None, None, message) when the first kernel call raises; `message` is
        the exception text truncated to 80 characters.
    """
    v_dim = hdim if hdim_v is None else hdim_v

    def _random_input(last_dim):
        return torch.randn(batch, seqlen, nheads, last_dim, dtype=torch.bfloat16, device="cuda")

    q = _random_input(hdim)
    k = _random_input(hdim)
    v = _random_input(v_dim)

    fwd_kwargs = {"softmax_scale": hdim ** -0.5, "causal": causal}
    if not (tile_m is None or tile_n is None):
        fwd_kwargs["tile_mn"] = (tile_m, tile_n)
    for option_name, option_value in (("mma_pv_is_rs", mma_pv_is_rs),
                                      ("intra_wg_overlap", intra_wg_overlap)):
        if option_value is not None:
            fwd_kwargs[option_name] = option_value

    # Only this first call is guarded; it is where an unsupported
    # configuration is reported back to the caller as an error string.
    try:
        out, _lse = _flash_attn_fwd(q, k, v, **fwd_kwargs)
    except Exception as e:
        return None, None, str(e)[:80]

    max_diff = None
    if check_correctness:
        # The reference runs in fp32 with heads moved to dim 1 for SDPA.
        q_ref, k_ref, v_ref = (t.transpose(1, 2).float() for t in (q, k, v))
        reference = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=causal)
        reference = reference.transpose(1, 2).to(torch.bfloat16)
        max_diff = (out.float() - reference.float()).abs().max().item()

    for _ in range(warmup):
        _flash_attn_fwd(q, k, v, **fwd_kwargs)
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    stop_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(rep):
        _flash_attn_fwd(q, k, v, **fwd_kwargs)
    stop_evt.record()
    torch.cuda.synchronize()

    ms = start_evt.elapsed_time(stop_evt) / rep
    # flops / ms / 1e9 == flops / seconds / 1e12, i.e. TFLOPS.
    total_flops = fwd_flops(batch, nheads, seqlen, hdim, hdim_v=v_dim, causal=causal)
    tflops = total_flops / ms / 1e9
    return ms, tflops, max_diff
def bench_bwd(batch, seqlen, nheads, hdim, causal, warmup=5, rep=30, hdim_v=None, **bwd_kwargs):
    """Time `_flash_attn_bwd` for one problem shape.

    A forward pass is run once (with `return_lse=True`) to obtain `out` and
    `lse`, a random `dout` is drawn, and the backward call is then timed.
    Extra keyword arguments are forwarded unchanged to `_flash_attn_bwd`.

    Returns:
        (ms, tflops, None) on success, with `ms` the mean time per backward
        call over `rep` timed iterations.
        (None, None, message) when the forward pass or the first backward call
        raises; `message` is the exception text truncated to 80 characters.
    """
    v_dim = hdim if hdim_v is None else hdim_v
    tensor_opts = {"device": "cuda", "dtype": torch.bfloat16}
    q = torch.randn(batch, seqlen, nheads, hdim, **tensor_opts)
    k = torch.randn(batch, seqlen, nheads, hdim, **tensor_opts)
    v = torch.randn(batch, seqlen, nheads, v_dim, **tensor_opts)
    scale = hdim ** -0.5

    try:
        out, lse = _flash_attn_fwd(q, k, v, softmax_scale=scale, causal=causal,
                                   return_lse=True)
    except Exception as e:
        return None, None, str(e)[:80]

    dout = torch.randn_like(out)

    def run_backward():
        _flash_attn_bwd(q, k, v, out, dout, lse, softmax_scale=scale,
                        causal=causal, **bwd_kwargs)

    # The first backward call is guarded separately so that a failure is
    # reported as an error string instead of aborting the whole run.
    try:
        run_backward()
    except Exception as e:
        return None, None, str(e)[:80]

    for _ in range(warmup):
        run_backward()
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    stop_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(rep):
        run_backward()
    stop_evt.record()
    torch.cuda.synchronize()

    ms = start_evt.elapsed_time(stop_evt) / rep
    # flops / ms / 1e9 == flops / seconds / 1e12, i.e. TFLOPS.
    total_flops = bwd_flops(batch, nheads, seqlen, hdim, causal, hdim_v=v_dim)
    tflops = total_flops / ms / 1e9
    return ms, tflops, None
# ── Preset configs ────────────────────────────────────────────────────────
# Forward tile-sweep presets, keyed by head dimension (hdim).
# Each entry is (tile_m, tile_n, mma_pv_is_rs, intra_wg_overlap); run_sweep_tiles
# unpacks the tuple in that order and passes it straight to bench_fwd.
# Head dimensions without a key here are skipped by run_sweep_tiles.
TILE_SWEEP_CONFIGS = {
    64: [
        (192, 192, False, True),
        (192, 192, True, True),
        (192, 128, True, True),
        (192, 128, False, True),
        (128, 128, True, True),
        (128, 192, True, True),
        (192, 96, True, True),
        (192, 96, False, True),
    ],
    96: [
        (192, 144, False, True),
        (192, 144, True, True),
        (192, 128, False, True),
        (192, 128, True, True),
        (192, 96, False, True),
        (192, 96, True, True),
        (128, 128, True, True),
        (128, 128, False, True),
    ],
    128: [
        (128, 128, True, True),
        (128, 128, False, True),
        (128, 96, True, True),
        (128, 96, False, True),
        (128, 160, True, True),
        (128, 176, True, True),
        (128, 192, True, True),
    ],
    192: [
        (128, 64, True, True),
        (128, 80, True, True),
        (128, 96, True, True),
        (128, 112, True, True),
        (128, 128, True, True),
    ],
    256: [
        (128, 48, True, True),
        (128, 64, True, True),
        (128, 80, True, True),
        (128, 96, True, True),
    ],
}
# All four (mma_pv_is_rs, intra_wg_overlap, label) combinations swept by
# run_sweep_rs_overlap; the label is what gets printed in the RS/OL column.
RS_OVERLAP_COMBOS = [
    (True, True, "RS+OL"),
    (True, False, "RS+noOL"),
    (False, True, "noRS+OL"),
    (False, False, "noRS+noOL"),
]
# Old-vs-new forward configs for the --compare-configs mode.
# The old and new tuples use the same field order as the tile-sweep presets:
# (tile_m, tile_n, mma_pv_is_rs, intra_wg_overlap).
COMPARE_CONFIGS = [
    # (hdim, causal, (old_tile_m, old_tile_n, old_rs, old_ol), (new...))
    (64, False, (192, 128, True, True), (192, 128, True, True)),
    (64, True, (192, 128, True, True), (192, 128, True, True)),
    (96, False, (192, 96, True, True), (192, 144, False, True)),
    (96, True, (192, 96, True, True), (192, 128, False, True)),
]
def _get_default_bwd_config(headdim, causal=False):
"""Default SM90 backward config for a given headdim."""
if headdim <= 128:
return dict(
m_block_size=64 if causal else 80,
n_block_size=128,
num_stages_Q=2,
num_stages_dO=2,
SdP_swapAB=True,
dKV_swapAB=False,
dQ_swapAB=not causal,
AtomLayoutMSdP=1,
AtomLayoutNdKV=2,
AtomLayoutMdQ=1,
num_threads=384,
)
elif headdim <= 192:
return dict(
m_block_size=64,
n_block_size=96,
num_stages_Q=1,
num_stages_dO=1,
SdP_swapAB=False,
dKV_swapAB=True,
dQ_swapAB=True,
AtomLayoutMSdP=1,
AtomLayoutNdKV=1,
AtomLayoutMdQ=1,
num_threads=512,
)
else:
return dict(
m_block_size=64,
n_block_size=64,
num_stages_Q=1,
num_stages_dO=1,
SdP_swapAB=False,
dKV_swapAB=False,
dQ_swapAB=False,
AtomLayoutMSdP=1,
AtomLayoutNdKV=1,
AtomLayoutMdQ=1,
num_threads=384,
)
# Maps optimization name -> function(headdim, causal) -> dict[label, kwargs] or None
# Each function returns two variants built on top of _get_default_bwd_config with a
# single key overridden, or None when the comparison does not apply: every entry
# returns None for hdim > 128, and "tile_m" additionally returns None when causal.
BWD_OPT_CONFIGS = {
    "V_in_regs": lambda hdim, causal: (
        None if hdim > 128 else {
            "baseline (V_in_regs=False)": {**_get_default_bwd_config(hdim, causal), "V_in_regs": False},
            "optimized (V_in_regs=True)": {**_get_default_bwd_config(hdim, causal), "V_in_regs": True},
        }
    ),
    # NOTE(review): presumably AtomLayoutNdKV=2 is what selects the RS path for the
    # dKV MMA (hence the key name) -- confirm against the kernel implementation.
    "mma_dkv_is_rs": lambda hdim, causal: (
        None if hdim > 128 else {
            "baseline (AtomLayoutNdKV=1)": {**_get_default_bwd_config(hdim, causal), "AtomLayoutNdKV": 1},
            "optimized (AtomLayoutNdKV=2)": {**_get_default_bwd_config(hdim, causal), "AtomLayoutNdKV": 2},
        }
    ),
    "Q_dO_pipeline_sharing": lambda hdim, causal: (
        None if hdim > 128 else {
            "baseline (dO_stage=1, separate)": {**_get_default_bwd_config(hdim, causal), "num_stages_dO": 1},
            "optimized (dO_stage=2, shared)": {**_get_default_bwd_config(hdim, causal), "num_stages_dO": 2},
        }
    ),
    # Compares m_block_size 64 vs 80 for the non-causal case only.
    "tile_m": lambda hdim, causal: (
        None if hdim > 128 or causal else {
            "tile_m=64": {**_get_default_bwd_config(hdim, causal), "m_block_size": 64},
            "tile_m=80": {**_get_default_bwd_config(hdim, causal), "m_block_size": 80},
        }
    ),
}
# ── Run modes ─────────────────────────────────────────────────────────────
def run_default(args):
    """Run the standard benchmark table for every requested direction.

    For each direction ('fwd', 'bwd', or both when args.direction == 'both')
    a header is printed, followed by one row per (hdim, seqlen, causal)
    combination. Failed runs print 'FAIL' followed by the error message
    returned by the bench function.
    """
    if args.direction != "both":
        directions = [args.direction]
    else:
        directions = ["fwd", "bwd"]

    for direction in directions:
        is_fwd = direction == "fwd"
        print(f"\n{'=' * 80}")
        print(f" SM90 {direction.upper()} (rep={args.rep})")
        print(f"{'=' * 80}")
        cols = f"{'hdim':>5} {'hdim_v':>6} {'causal':>6} {'batch':>5} {'seqlen':>6} {'ms':>8} {'TFLOPS':>8}"
        if is_fwd:
            # Only the forward benchmark reports a correctness diff.
            cols += f" {'max_diff':>10}"
        print(cols)
        print("-" * 80)

        for hdim, hdim_v in args.hdim:
            nheads = nheads_for_hdim(hdim)
            for seqlen in args.seqlen:
                batch = auto_batch(seqlen, args.batch)
                for causal in get_causals(args):
                    bench = bench_fwd if is_fwd else bench_bwd
                    ms, tflops, diff = bench(batch, seqlen, nheads, hdim, causal,
                                             warmup=args.warmup, rep=args.rep, hdim_v=hdim_v)
                    if ms is None:
                        # On failure `diff` carries the truncated error message.
                        print(f"{hdim:>5} {hdim_v:>6} {str(causal):>6} {batch:>5} {seqlen:>6} {'FAIL':>8} {'':>8} {diff}")
                        continue
                    line = f"{hdim:>5} {hdim_v:>6} {str(causal):>6} {batch:>5} {seqlen:>6} {ms:>8.3f} {tflops:>8.1f}"
                    if diff is not None:
                        line += f" {diff:>10.6f}"
                    print(line)
def run_sweep_tiles(args):
    """Benchmark every forward tile configuration across the requested seqlens.

    For each ``(hdim, hdim_v)`` in ``args.hdim`` the candidate
    ``(tile_m, tile_n, rs, ol)`` tuples come from ``TILE_SWEEP_CONFIGS``;
    head dims without an entry are reported and skipped. One table is printed
    per (hdim, causal) pair, with one TFLOPS column per sequence length.
    """
    seqlens = args.seqlen
    for hdim, hdim_v in args.hdim:
        nheads = nheads_for_hdim(hdim)
        configs = TILE_SWEEP_CONFIGS.get(hdim, [])
        if not configs:
            print(f"No tile sweep configs for hdim={hdim}, skipping")
            continue

        for causal in get_causals(args):
            header = f"{'hdim':>5} {'causal':>6} {'tile_m':>6} {'tile_n':>6} {'pv_rs':>5} {'ol':>5}"
            header += "".join(f" {'s=' + str(sl):>8}" for sl in seqlens)
            print(header)
            print("=" * len(header))

            for tile_m, tile_n, rs, ol in configs:
                cells = [f"{hdim:>5} {str(causal):>6} {tile_m:>6} {tile_n:>6} {str(rs):>5} {str(ol):>5}"]
                for sl in seqlens:
                    batch = auto_batch(sl, args.batch)
                    _ms, tflops, _diff = bench_fwd(
                        batch, sl, nheads, hdim, causal,
                        tile_m, tile_n, rs, ol,
                        check_correctness=False, warmup=args.warmup, rep=args.rep, hdim_v=hdim_v,
                    )
                    # A falsy TFLOPS value is rendered as FAIL.
                    if tflops:
                        cells.append(f" {tflops:>8.1f}")
                    else:
                        cells.append(f" {'FAIL':>8}")
                print("".join(cells))
            print()
def run_sweep_rs_overlap(args):
    """Benchmark each RS / intra-WG-overlap combination for the forward pass.

    A fixed tile shape is used per head dim (falling back to 128x128 for head
    dims not listed below), and every entry of ``RS_OVERLAP_COMBOS`` is timed
    at every requested sequence length. One table is printed per
    (hdim, causal) pair.
    """
    seqlens = args.seqlen
    # Tile shape (tile_m, tile_n) used for each head dim in this sweep.
    tile_for_hdim = {64: (192, 128), 96: (192, 128), 128: (128, 128)}

    for hdim, hdim_v in args.hdim:
        nheads = nheads_for_hdim(hdim)
        tile_m, tile_n = tile_for_hdim.get(hdim, (128, 128))

        for causal in get_causals(args):
            c_str = "causal" if causal else "non-causal"

            header = f"{'Config':<30} {'RS/OL':<12}"
            header += "".join(f" {'s=' + str(sl):>8}" for sl in seqlens)
            print(header)
            print("=" * len(header))

            # The row label does not depend on the RS/OL combination.
            label = f"hdim{hdim} {c_str} {tile_m}x{tile_n}"
            for rs, ol, rs_label in RS_OVERLAP_COMBOS:
                cells = [f"{label:<30} {rs_label:<12}"]
                for sl in seqlens:
                    batch = auto_batch(sl, args.batch)
                    _ms, tflops, _diff = bench_fwd(
                        batch, sl, nheads, hdim, causal,
                        tile_m, tile_n, rs, ol,
                        check_correctness=False, warmup=args.warmup, rep=args.rep, hdim_v=hdim_v,
                    )
                    if tflops:
                        cells.append(f" {tflops:>8.1f}")
                    else:
                        cells.append(f" {'FAIL':>8}")
                print("".join(cells))
            print()
def run_compare_configs(args):
    """Print forward TFLOPS for the OLD and NEW tile config of each entry.

    Iterates ``COMPARE_CONFIGS`` (tuples of ``hdim, causal, old, new``), where
    each config is unpacked positionally into ``bench_fwd`` and its first four
    items are shown in the row label as tile shape, RS and OL. A separator
    line is printed after each OLD/NEW pair.
    """
    seqlens = args.seqlen

    header = f"{'Config':<50}" + "".join(f" {'s=' + str(sl):>8}" for sl in seqlens)
    print(header)
    print("=" * len(header))

    for hdim, causal, old, new in COMPARE_CONFIGS:
        nheads = nheads_for_hdim(hdim)
        c_str = "causal" if causal else "non-causal"

        for label_prefix, cfg in (("OLD", old), ("NEW", new)):
            label = f"hdim{hdim} {c_str:<11} {label_prefix} {cfg[0]}x{cfg[1]} RS={cfg[2]} OL={cfg[3]}"
            cells = [f"{label:<50}"]
            for sl in seqlens:
                batch = auto_batch(sl, args.batch)
                _ms, tflops, _diff = bench_fwd(
                    batch, sl, nheads, hdim, causal, *cfg,
                    check_correctness=False, warmup=args.warmup, rep=args.rep,
                )
                if tflops:
                    cells.append(f" {tflops:>8.1f}")
                else:
                    cells.append(f" {'FAIL':>8}")
            print("".join(cells))
        print("-" * len(header))
def run_sweep_bwd_opts(args):
    """Sweep backward kernel optimizations (V_in_regs, mma_dkv_is_rs, etc.).

    For every entry of ``BWD_OPT_CONFIGS`` (name -> callable taking
    ``(hdim, causal)`` and returning either ``None`` or a dict of
    ``label -> bench_bwd kwargs``), each labelled variant is timed with
    ``bench_bwd`` for every head dim, causal setting and sequence length.
    Head dims for which the callable returns ``None`` are skipped, and the
    section banner is only printed once a head dim actually has configs.
    """
    seqlens = args.seqlen
    for opt_name, get_configs_fn in BWD_OPT_CONFIGS.items():
        for causal in get_causals(args):
            c_str = "causal" if causal else "non-causal"
            has_any = False
            for hdim, hdim_v in args.hdim:
                configs = get_configs_fn(hdim, causal)
                if configs is None:
                    continue
                if not has_any:
                    # Print the banner lazily so unsupported sweeps stay silent.
                    print(f"\n{'=' * 70}")
                    print(f"BWD Optimization: {opt_name} ({c_str})")
                    print(f"{'=' * 70}")
                    has_any = True
                nheads = nheads_for_hdim(hdim)
                print(f"\n hdim={hdim}:")
                for sl in seqlens:
                    batch = auto_batch(sl, args.batch)
                    # bench_bwd already returns TFLOPS, so the FLOP count is
                    # not recomputed here (the previous value was never used).
                    if len(seqlens) > 1:
                        print(f" seqlen={sl}, batch={batch}:")
                    for label, kwargs in configs.items():
                        ms, tflops, err = bench_bwd(batch, sl, nheads, hdim, causal,
                                                    warmup=args.warmup, rep=args.rep, hdim_v=hdim_v, **kwargs)
                        if ms is not None:
                            print(f" {label:40s}: {ms:6.2f} ms ({tflops:6.1f} TFLOPS)")
                        else:
                            print(f" {label:40s}: FAIL {err}")
# ── Main ──────────────────────────────────────────────────────────────────
def main():
    """Parse the command line and run the selected SM90 benchmark mode.

    At most one of the sweep/compare flags may be given (they form a mutually
    exclusive group); with none of them the default fwd/bwd benchmark runs.
    The torch RNG is seeded with 0 before any benchmark starts.
    """
    parser = argparse.ArgumentParser(
        description="Unified SM90 attention benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Problem-size and timing options.
    parser.add_argument(
        "--direction", choices=["fwd", "bwd", "both"], default="both",
        help="Benchmark direction (default: both)",
    )
    parser.add_argument(
        "--hdim", type=parse_headdims, default=[(64, 64), (96, 96), (128, 128)],
        help="Head dims, comma-separated. Each is hdim or hdim-hdim_v. E.g. 64,128,192-128",
    )
    parser.add_argument(
        "--seqlen", type=csv_ints, default=[8192],
        help="Sequence lengths, comma-separated (default: 8192)",
    )
    parser.add_argument(
        "--batch", type=int, default=0,
        help="Batch size (0 = auto ~32k tokens)",
    )
    parser.add_argument(
        "--warmup", type=int, default=5,
        help="Warmup iterations (default: 5)",
    )
    parser.add_argument(
        "--rep", type=int, default=30,
        help="Repetitions per benchmark (default: 30)",
    )
    parser.add_argument("--causal-only", action="store_true")
    parser.add_argument("--non-causal-only", action="store_true")

    # Run modes: argparse rejects command lines that combine two of these.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--sweep-tiles", action="store_true",
                      help="Sweep fwd tile sizes")
    mode.add_argument("--sweep-rs-overlap", action="store_true",
                      help="Sweep fwd RS/overlap combos")
    mode.add_argument("--compare-configs", action="store_true",
                      help="Compare old vs new fwd tile configs")
    mode.add_argument("--sweep-bwd-opts", action="store_true",
                      help="Sweep bwd optimizations (V_in_regs, mma_dkv_is_rs, etc.)")

    args = parser.parse_args()
    torch.manual_seed(0)

    # First selected mode wins; fall back to the default benchmark.
    dispatch = (
        (args.sweep_tiles, run_sweep_tiles),
        (args.sweep_rs_overlap, run_sweep_rs_overlap),
        (args.compare_configs, run_compare_configs),
        (args.sweep_bwd_opts, run_sweep_bwd_opts),
    )
    runner = next((fn for selected, fn in dispatch if selected), run_default)
    runner(args)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: benchmarks/benchmark_alibi.py
================================================
# Copyright (c) 2024, Sanghun Cho, Tri Dao.
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.layers.rotary import apply_rotary_emb
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
try:
import xformers.ops as xops
except ImportError:
xops = None
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    """Build random cos/sin tables for the rotary-embedding benchmark.

    Args:
        seqlen: sequence length; the tables have ``2 * seqlen`` rows.
        rotary_dim: rotary dimension, must be even; tables have
            ``rotary_dim // 2`` columns.
        device: device on which the random angles are sampled.
        dtype: dtype of the returned tables.

    Returns:
        ``(cos, sin)`` tensors of shape ``(2 * seqlen, rotary_dim // 2)``.
    """
    assert rotary_dim % 2 == 0
    half_dim = rotary_dim // 2
    # Uniform random angles in [0, 2*pi); only the shape matters for timing.
    phases = torch.rand(seqlen * 2, half_dim, device=device) * 2 * math.pi
    cos_table = phases.cos().to(dtype=dtype)
    sin_table = phases.sin().to(dtype=dtype)
    return cos_table, sin_table
def flash_rotary(q, k, v, cos, sin, causal=False):
    """Apply rotary embeddings to ``q`` and ``k`` in place, then run flash attention.

    Both tensors are rotated with the same ``cos``/``sin`` tables, zero
    sequence offset and the non-interleaved layout; ``v`` is passed through
    unchanged. Returns the result of ``flash_attn_func``.
    """
    rotary_opts = dict(seqlen_offsets=0, interleaved=False, inplace=True)
    # corrected by @tridao comments
    q = apply_rotary_emb(q, cos, sin, **rotary_opts)
    k = apply_rotary_emb(k, cos, sin, **rotary_opts)
    return flash_attn_func(q, k, v, causal=causal)
def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
):
    """Materialize an additive ALiBi attention bias from per-head slopes.

    Args:
        slopes: 2-D tensor of shape ``(batch, nheads)``.
        seqlen_q, seqlen_k: query / key sequence lengths.
        query_padding_mask, key_padding_mask: optional masks whose sum over
            the last dim gives per-sample lengths; only used when not causal.
        causal: if True, return the compact ``(batch, nheads, 1, seqlen_k)``
            bias ``slope * [-(seqlen_k - 1), ..., 0]``.

    Returns:
        Causal: float32 arange times the slopes, as above.
        Non-causal: ``-slope * |i + sk - sq - j|`` broadcast to
        ``(batch, nheads, seqlen_q, seqlen_k)``, in the dtype of ``slopes``.
    """
    batch, nheads = slopes.shape
    device = slopes.device
    # (b, h) -> (b, h, 1, 1) so the slopes broadcast over query/key positions.
    slopes = slopes.reshape(batch, nheads, 1, 1)

    if causal:
        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes

    def _effective_len(padding_mask, full_len):
        # Without a mask every sample uses the full length; with one, the
        # per-sample count is shaped (b, 1, 1, 1) for broadcasting.
        if padding_mask is None:
            return full_len
        counts = padding_mask.sum(-1)
        (num_samples,) = counts.shape
        return counts.reshape(num_samples, 1, 1, 1)

    row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).unsqueeze(1)
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    sk = _effective_len(key_padding_mask, seqlen_k)
    sq = _effective_len(query_padding_mask, seqlen_q)
    relative_pos = torch.abs(row_idx + sk - sq - col_idx)
    return -slopes * relative_pos.to(dtype=slopes.dtype)
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Return the attention FLOP count used for the TFLOPs/s figures.

    The forward count is ``4 * batch * seqlen**2 * nheads * headdim``,
    floor-halved when ``causal``. ``mode="bwd"`` scales it by 2.5 and
    ``mode="fwd_bwd"`` by 3.5 (both return floats; ``"fwd"`` returns an int
    for int inputs).
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    divisor = 2 if causal else 1
    fwd_flops = 4 * batch * seqlen**2 * nheads * headdim // divisor
    if mode == "fwd":
        return fwd_flops
    scale = 2.5 if mode == "bwd" else 3.5
    return scale * fwd_flops
def efficiency(flop, time):
    """Convert a FLOP count and a duration into ``flop / time / 1e12``.

    A NaN ``time`` (used by the benchmark loop for skipped runs) yields 0.0.
    """
    if math.isnan(time):
        return 0.0
    return flop / time / 10**12
def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
    """Reference softmax attention in plain PyTorch, with an optional additive bias.

    Arguments:
        q, k, v: (batch_size, seqlen, nheads, head_dim)
        dropout_p: float
        causal: if True, add a -10000 upper-triangular mask before the softmax
        attn_bias: None, or a tensor of shape (batch_size, nheads, t, seqlen)
            or (1, nheads, t, seqlen) with t in {1, seqlen}; it is added to
            the scaled scores and must have the same dtype as q.
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, nheads, d = q.shape
    seqlen_k = k.shape[1]
    # (b, t, h, d) -> (b*h, t, d) and (b, s, h, d) -> (b*h, d, s) for bmm.
    q = q.permute(0, 2, 1, 3).reshape(batch_size * nheads, seqlen, d)
    k = k.permute(0, 2, 3, 1).reshape(batch_size * nheads, d, seqlen_k)
    softmax_scale = 1.0 / math.sqrt(d)
    if attn_bias is not None:
        # Broadcast a bias shared across the batch, then fold (b, h) together.
        bias = attn_bias.expand(batch_size, -1, -1, -1)
        bias = bias.reshape(batch_size * nheads, bias.shape[2], bias.shape[3])
        scores = torch.baddbmm(bias, q, k, beta=1.0, alpha=softmax_scale)
    else:
        # Preallocate attn_weights for `baddbmm`. beta=0 makes baddbmm ignore
        # the uninitialized contents of the buffer; with beta=1.0 that garbage
        # (possibly NaN/inf) would be added to the scores.
        scores = torch.empty(batch_size * nheads, seqlen, seqlen_k, dtype=q.dtype, device=q.device)
        scores = torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale)
    scores = scores.reshape(batch_size, nheads, seqlen, seqlen_k)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=q.dtype)
def time_fwd_bwd(func, *args, **kwargs):
    """Time ``func`` with ``benchmark_fwd_bwd`` and return the two mean times.

    All arguments are forwarded unchanged. The second element of each returned
    timing pair exposes ``.mean``; the result is ``(forward_mean, backward_mean)``.
    """
    fwd_timing, bwd_timing = benchmark_fwd_bwd(func, *args, **kwargs)
    fwd_mean = fwd_timing[1].mean
    bwd_mean = bwd_timing[1].mean
    return fwd_mean, bwd_mean
# ── Sweep settings ──
repeats = 30
device = 'cuda'
dtype = torch.float16

# (batch_size, seqlen) pairs benchmarked for every causal/headdim setting.
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048  # nheads is derived as dim // headdim in the loop below
dropout_p = 0.0

# Methods reported per config; xformers is listed only when it was importable.
methods = ["fa2_alibi", "torch"]
if xops is not None:
    methods.append("xformers")
methods.extend(["sdpa", "fa2_baseline", "fa2_rotary"])

# Result tables keyed by (config, method): times and derived speeds.
time_f, time_b, time_f_b = {}, {}, {}
speed_f, speed_b, speed_f_b = {}, {}, {}
# Main sweep: for every (causal, headdim, batch_size, seqlen) config, time each
# method's forward and backward pass, then print TFLOPs/s figures and a CSV row.
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                                   requires_grad=True) for _ in range(3)]
            # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
            alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
            # Dense bias for the backends that take an explicit attention bias.
            attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
            attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)

            # FlashAttention without ALiBi (baseline).
            f, b = time_fwd_bwd(
                flash_attn_func,
                q, k, v,
                dropout_p,
                causal=causal,
                # alibi_slopes=alibi_slopes,
                alibi_slopes=None,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_baseline"] = f
            time_b[config, "fa2_baseline"] = b

            # FlashAttention with ALiBi slopes; inputs are re-created as fresh leaves.
            q = q.detach().requires_grad_(True)
            k = k.detach().requires_grad_(True)
            v = v.detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_func,
                q, k, v,
                dropout_p,
                causal=causal,
                alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
                # alibi_slopes=None,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_alibi"] = f
            time_b[config, "fa2_alibi"] = b

            # Plain PyTorch reference; recorded as NaN when it cannot run.
            try:
                q = q.detach().requires_grad_(True)
                k = k.detach().requires_grad_(True)
                v = v.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch,
                    q, k, v,
                    dropout_p,
                    causal=causal,
                    attn_bias=attn_bias,
                    repeats=repeats,
                    verbose=False
                )
            except Exception:  # Skip if OOM (but let Ctrl-C / SystemExit propagate)
                f, b = float('nan'), float('nan')
            time_f[config, "torch"] = f
            time_b[config, "torch"] = b

            # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
            with torch.backends.cuda.sdp_kernel(enable_flash=False):
                q_pt = q.detach().requires_grad_(True).transpose(1, 2)
                k_pt = k.detach().requires_grad_(True).transpose(1, 2)
                v_pt = v.detach().requires_grad_(True).transpose(1, 2)
                f, b = time_fwd_bwd(
                    F.scaled_dot_product_attention,
                    q_pt, k_pt, v_pt,
                    attn_mask=attn_bias,
                    dropout_p=dropout_p,
                    is_causal=causal,
                    repeats=repeats,
                    verbose=False
                )
            time_f[config, "sdpa"] = f
            time_b[config, "sdpa"] = b

            if xops is not None:
                q = q.detach().requires_grad_(True)
                k = k.detach().requires_grad_(True)
                v = v.detach().requires_grad_(True)
                if causal:
                    attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
                    # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
                    # `flshattB@v2.3.6` is not supported because:
                    # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
                    # `cutlassB` is not supported because:
                    # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
                    attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
                else:
                    attn_bias_xops = attn_bias.to(dtype=q.dtype)
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention,
                    q, k, v,
                    attn_bias_xops,
                    dropout_p,
                    repeats=repeats,
                    verbose=False
                )
                time_f[config, "xformers"] = f
                time_b[config, "xformers"] = b

            # FlashAttention with rotary embeddings applied to q and k.
            q = q.detach().requires_grad_(True)
            k = k.detach().requires_grad_(True)
            v = v.detach().requires_grad_(True)
            cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
            f, b = time_fwd_bwd(
                flash_rotary,
                q, k, v,
                cos, sin,
                causal,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_rotary"] = f
            time_b[config, "fa2_rotary"] = b

            # Report: one human-readable line per method plus one CSV row per config.
            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            csv_output = ""
            csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
            for method in methods:
                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                speed_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                    time_b[config, method]
                )
                speed_f_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                    time_f_b[config, method]
                )
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                )
                csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
            print(csv_output)
================================================
FILE: benchmarks/benchmark_attn.py
================================================
import argparse
import time
import torch
try:
import cudnn
except ImportError:
cudnn = None
from einops import rearrange
from flash_attn.cute.bench_utils import (
flops, attention_ref,
cudnn_fwd_setup, cudnn_bwd_setup,
)
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
except ImportError:
flash_attn_func = None
flash_attn_varlen_func = None
try:
from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python
except ImportError:
flash_attn_func_python = None
flash_attn_varlen_func_python = None
try:
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
flash_attn_func_v3 = None
flash_attn_varlen_func_v3 = None
if torch.cuda.get_device_capability()[0] != 9:
flash_attn_func_v3 = None
from triton.testing import do_bench
# ── Autograd backward helper ────────────────────────────────────────────────
def _make_bwd_fn(fwd_fn, g, inputs):
"""Run fwd once, return a closure that benchmarks backward.
Args:
fwd_fn: zero-arg callable that runs the forward pass (with autograd).
g: gradient tensor (b, seqlen, nheads, headdim_v).
inputs: list of input tensors whose .grad should be cleared each iteration.
"""
out = fwd_fn()
if isinstance(out, tuple):
out = out[0]
g_match = g[:out.shape[0]] if g.shape[0] != out.shape[0] else g # handle varlen
def bwd_fn():
for x in inputs:
x.grad = None
out.backward(g_match, retain_graph=True)
return bwd_fn
# ── Backend definitions ─────────────────────────────────────────────────────
# Each setup_* function takes a context dict and returns (fwd_fn, bwd_fn).
# Either can be None if the backend doesn't support that direction for the
# given config. fwd_fn / bwd_fn are zero-arg callables suitable for do_bench.
def setup_standard(ctx):
if ctx["dtype"] == torch.float8_e4m3fn:
return None, None
q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"]
fwd_fn = lambda: attention_ref(q, k, v, causal=causal)
bwd_fn = _make_bwd_fn(fwd_fn, g, [q, k, v]) if ctx["has_backward"] else None
return fwd_fn, bwd_fn
def setup_fa2(ctx):
    """Set up the FA2 backend (``flash_attn_func`` / ``flash_attn_varlen_func``).

    Returns ``(None, None)`` when FA2 is not importable, the dtype is fp8, or
    ``headdim != headdim_v``. Otherwise returns ``(fwd_fn, bwd_fn)``; the
    backward closure is built from a forward call that additionally passes
    ``deterministic`` and is ``None`` unless ``ctx["has_backward"]``.
    """
    unsupported = (
        flash_attn_func is None
        or ctx["dtype"] == torch.float8_e4m3fn
        or ctx["headdim"] != ctx["headdim_v"]
    )
    if unsupported:
        return None, None

    q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"]
    dropout_p = ctx["dropout_p"]
    window_size_fa = ctx["window_size_fa"]
    softcap = ctx["softcap"]
    deterministic = ctx["deterministic"]

    if ctx["varlen"]:
        qu, ku, vu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"]
        csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"]
        sq, sk = ctx["seqlen_q"], ctx["seqlen"]
        grad_inputs = [qu, ku, vu]

        def run(**extra):
            return flash_attn_varlen_func(qu, ku, vu, csq, csk, sq, sk, dropout_p,
                                          causal=causal, window_size=window_size_fa,
                                          softcap=softcap, **extra)
    else:
        grad_inputs = [q, k, v]

        def run(**extra):
            return flash_attn_func(q, k, v, dropout_p,
                                   causal=causal, window_size=window_size_fa,
                                   softcap=softcap, **extra)

    # The timed forward passes no extra kwargs; only the backward graph is
    # built with the deterministic flag.
    fwd_fn = run
    bwd_fn = None
    if ctx["has_backward"]:
        bwd_fn = _make_bwd_fn(lambda: run(deterministic=deterministic), g, grad_inputs)
    return fwd_fn, bwd_fn
def setup_cudnn(ctx):
    """Set up the cuDNN backend via ``cudnn_fwd_setup`` / ``cudnn_bwd_setup``.

    Returns ``(None, None)`` when cuDNN is not importable, ``headdim > 256``
    or the dtype is fp8. The forward is executed once before the backward
    setup so that its output and LSE buffers are populated.
    """
    if cudnn is None or ctx["headdim"] > 256 or ctx["dtype"] == torch.float8_e4m3fn:
        return None, None

    q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"]
    window_size_left = ctx["window_size"][0]

    # cuDNN expects (batch, nheads, seqlen, headdim) layout
    qt, kt, vt, gt = (t.transpose(1, 2) for t in (q, k, v, g))

    fwd_fn, o_gpu, lse_gpu = cudnn_fwd_setup(qt, kt, vt, causal=causal, window_size_left=window_size_left)
    if not ctx["has_backward"]:
        return fwd_fn, None

    fwd_fn()  # populate o and lse for bwd graph
    bwd_fn = cudnn_bwd_setup(qt, kt, vt, o_gpu, gt, lse_gpu, causal=causal, window_size_left=window_size_left)
    return fwd_fn, bwd_fn
def setup_fa3(ctx):
    """Set up the FA3 backend (``flash_attn_func_v3`` / ``flash_attn_varlen_func_v3``).

    Returns ``(None, None)`` when FA3 is unavailable. The forward closure uses
    the paged K/V tensors when ``ctx["page_size"]`` is set (dense path only).
    The backward closure is only built for non-fp8 dtypes with
    ``headdim == headdim_v`` and benchmarks the same window / softcap /
    deterministic settings as requested in ``ctx``.
    """
    if flash_attn_func_v3 is None:
        return None, None
    q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"]
    window_size_fa, softcap = ctx["window_size_fa"], ctx["softcap"]
    num_splits, pack_gqa, deterministic = ctx["num_splits"], ctx["pack_gqa"], ctx["deterministic"]
    k_use = ctx.get("k_paged", k) if ctx["page_size"] is not None else k
    v_use = ctx.get("v_paged", v) if ctx["page_size"] is not None else v

    if ctx["varlen"]:
        qu, ku, vu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"]
        csq, csk, sq, sk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"], ctx["seqlen_q"], ctx["seqlen"]
        fwd_fn = lambda: flash_attn_varlen_func_v3(qu, ku, vu, csq, csk, sq, sk, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
    else:
        fwd_fn = lambda: flash_attn_func_v3(q, k_use, v_use, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)

    # FA3 bwd only supports headdim == headdim_v and non-fp8
    bwd_fn = None
    if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn and ctx["headdim"] == ctx["headdim_v"]:
        # Use the same window_size_fa tuple as the forward call above, and
        # honour the deterministic flag on both the varlen and dense paths.
        if ctx["varlen"]:
            bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_v3(qu, ku, vu, csq, csk, sq, sk, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu])
        else:
            bwd_fn = _make_bwd_fn(lambda: flash_attn_func_v3(q, k, v, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic), g, [q, k, v])
    return fwd_fn, bwd_fn
def setup_fa4(ctx):
    """Set up the FA4 backend (``flash_attn.cute`` Python interface).

    Returns ``(None, None)`` when the interface is unavailable. The forward
    closure uses the paged K/V tensors (and page table on the varlen path)
    when ``ctx["page_size"]`` is set; the backward closure always uses the
    unpaged tensors and is skipped for fp8.
    """
    if flash_attn_func_python is None:
        return None, None
    q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"]
    window_size, softcap = ctx["window_size"], ctx["softcap"]
    pack_gqa, deterministic = ctx["pack_gqa"], ctx["deterministic"]
    sinks = ctx["sinks"]
    paged = ctx["page_size"] is not None
    k_use = ctx.get("k_paged", k) if paged else k
    v_use = ctx.get("v_paged", v) if paged else v

    if ctx["varlen"]:
        qu = ctx["q_unpad"]
        ku = ctx.get("k_paged", ctx["k_unpad"]) if paged else ctx["k_unpad"]
        vu = ctx.get("v_paged", ctx["v_unpad"]) if paged else ctx["v_unpad"]
        csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"]
        pt = ctx["page_table"]
        fwd_fn = lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, page_table=pt, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa)
    else:
        fwd_fn = lambda: flash_attn_func_python(q, k_use, v_use, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa)

    bwd_fn = None
    if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn:
        if ctx["varlen"]:
            # Use dedicated names for the backward tensors: the forward lambda
            # above captures qu/ku/vu by name, so rebinding them here would
            # silently make a paged forward run on the unpaged tensors.
            qu_bwd, ku_bwd, vu_bwd = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"]
            bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu_bwd, ku_bwd, vu_bwd, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), g, [qu_bwd, ku_bwd, vu_bwd])
        else:
            bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, causal=causal, softcap=softcap, deterministic=deterministic), g, [q, k, v])
    return fwd_fn, bwd_fn
# Ordered list of (display_name, cli_name, setup_fn).
#   display_name: column header / key used in the timing tables.
#   cli_name:     value accepted by --backend to select this backend.
#   setup_fn:     callable taking the context dict and returning (fwd_fn, bwd_fn).
# The list order is the order in which backends are benchmarked and shown.
BACKENDS = [
    ("Standard", "standard", setup_standard),
    ("FA2", "fa2", setup_fa2),
    ("cuDNN", "cudnn", setup_cudnn),
    ("FA3", "fa3", setup_fa3),
    ("FA4", "fa4", setup_fa4),
]
def parse_int_k(s):
    """Convert a string such as '512' or '8k' to an int; a k/K suffix means x1024.

    Surrounding whitespace is ignored and the suffix is case-insensitive,
    e.g. '8k' -> 8192.
    """
    text = s.strip().lower()
    if not text.endswith("k"):
        return int(text)
    return int(text[:-1]) * 1024
def csv_ints(s):
    """Split ``s`` on commas and parse each piece with ``parse_int_k``.

    Example: '512,1k,2k' -> [512, 1024, 2048].
    """
    return list(map(parse_int_k, s.split(",")))
def parse_headdims(s):
    """Parse a comma-separated list of head-dim specs into (hdim, hdim_v) pairs.

    Each entry is either ``hdim`` (value head dim equals ``hdim``) or
    ``hdim-hdim_v``.

    Examples:
        '128' -> [(128, 128)]
        '192-128' -> [(192, 128)]
        '64,128,192' -> [(64, 64), (128, 128), (192, 192)]
        '64,128,192-128,192' -> [(64, 64), (128, 128), (192, 128), (192, 192)]
    """
    pairs = []
    for entry in s.split(","):
        fields = entry.split("-")
        if len(fields) > 1:
            # 'hdim-hdim_v': only the first two fields are used.
            pairs.append((int(fields[0]), int(fields[1])))
        else:
            hdim = int(fields[0])
            pairs.append((hdim, hdim))
    return pairs
def csv_strs(s):
    """Split ``s`` on commas and strip whitespace, e.g. 'fa3, fa4' -> ['fa3', 'fa4']."""
    return list(map(str.strip, s.split(",")))
def parse_args():
    """Build the benchmark's argument parser and parse the command line.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="Benchmark FlashAttention")
    add = parser.add_argument

    # Problem shape.
    add("--headdim", type=parse_headdims, default=[(128, 128)],
        help="Head dim(s), comma-separated. Each is hdim or hdim-hdim_v. E.g. 64,128,192-128")

    # What to run.
    add("--fwd", action="store_true", help="Run forward only")
    add("--bwd", action="store_true", help="Run backward only")
    add("--varlen", action="store_true", default=False)
    add("--causal", type=str.lower, choices=["true", "false", "both"], default="both",
        help="Causal mode (default: both)")

    # Sequence length and batch sizing.
    add("--seqlen", type=csv_ints, default=[8192],
        help="Sequence length(s), comma-separated. Supports k suffix, e.g. 1k,2k,8k")
    add("--total-seqlen", type=parse_int_k, default="32k",
        help="Total sequence length for batch sizing (default: 32k)")
    add("--batch-size", type=int, default=None,
        help="Batch size (default: total_seqlen // seqlen)")
    add("--deterministic", action="store_true", default=False)

    # Head counts.
    add("--nheads", type=int, default=None,
        help="Number of Q heads (default: 32 for hdim<=64, 16 for hdim<=192, 8 for hdim>192)")
    add("--nheads-kv", type=int, default=None,
        help="Number of KV heads (default: nheads)")
    add("--gqa-ratio", type=int, default=None,
        help="GQA ratio (nheads // nheads_kv). Ignored if --nheads-kv is set.")

    # Backend selection and timing.
    add("--backend", type=csv_strs, default=["all"],
        help="Which backends to benchmark, comma-separated (choices: all,standard,fa2,fa3,fa4,cudnn)")
    add("--warmup", type=int, default=5,
        help="Warmup iterations (default: 5)")
    add("--rep", type=int, default=10,
        help="Repetitions per benchmark (default: 10)")

    return parser.parse_args()
def main():
    """Benchmark the selected attention backends and print result tables.

    For every (headdim, headdim_v), sequence length and causal setting, a
    context dict with the inputs is built and handed to each active backend's
    setup function; the returned forward/backward callables are timed with
    ``do_bench``. Finally one table per direction is printed with
    ``ms/TFLOPS`` per backend. With neither ``--fwd`` nor ``--bwd`` only the
    forward pass is benchmarked.
    """
    args = parse_args()
    headdim_pairs = args.headdim  # list of (hdim, hdim_v) tuples

    # Parse fwd/bwd: if neither specified, do fwd only
    has_forward = args.fwd or not args.bwd
    has_backward = args.bwd

    # Parse causal
    if args.causal == 'true':
        causal_vals = [True]
    elif args.causal == 'false':
        causal_vals = [False]
    else:
        causal_vals = [False, True]

    seqlen_list = args.seqlen
    varlen = args.varlen

    # Filter backends to those requested and available
    known = {cli for _, cli, _ in BACKENDS}
    enabled = set(args.backend)
    # A misspelled backend would otherwise be dropped silently and produce
    # empty output, so tell the user which names were not recognised.
    unknown = sorted(enabled - known - {'all'})
    if unknown:
        print(f"Warning: ignoring unknown backend(s): {', '.join(unknown)} "
              f"(choices: all, {', '.join(sorted(known))})")
    if 'all' in enabled:
        enabled = known
    active_backends = [(name, cli, fn) for name, cli, fn in BACKENDS if cli in enabled]

    # Parameters
    torch.manual_seed(0)
    dropout_p = 0.0
    dtype = torch.bfloat16
    # Random inputs are generated in bf16 when the benchmark dtype is fp8.
    dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    device = 'cuda'
    page_size = None
    softcap = 0.0
    deterministic = args.deterministic
    warmup, rep = args.warmup, args.rep

    # Timings keyed by (cfg, display_name).
    time_f = {}
    time_b = {}

    for headdim, headdim_v in headdim_pairs:
        nheads = args.nheads if args.nheads is not None else (32 if headdim <= 64 else 16 if headdim <= 192 else 8)
        if args.nheads_kv is not None:
            nheads_kv = args.nheads_kv
        elif args.gqa_ratio is not None:
            nheads_kv = nheads // args.gqa_ratio
        else:
            nheads_kv = nheads
        sinks = None
        num_splits = 0
        window_size = (None, None)
        window_size_fa = (-1, -1)
        pack_gqa = None

        for seqlen in seqlen_list:
            batch_size = args.batch_size if args.batch_size is not None else max(1, args.total_seqlen // seqlen)
            seqlen_q = seqlen
            q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward)
            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward)
            v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward)
            q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]]
            g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen)

            # Varlen tensors
            q_unpad = k_unpad = v_unpad = cu_seqlens_q = cu_seqlens_k = None
            if varlen:
                q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]]
                cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
                cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None

            # Paged KV tensors
            k_paged = v_paged = page_table = None
            if page_size is not None:
                assert seqlen % page_size == 0
                k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]]
                page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),
                                       "(b s) -> b s", s=seqlen // page_size)

            for causal in causal_vals:
                cfg = (headdim, headdim_v, causal, seqlen, batch_size, nheads)

                # Build context dict shared by all backends
                ctx = dict(
                    q=q, k=k, v=v, g=g, causal=causal,
                    headdim=headdim, headdim_v=headdim_v, dtype=dtype,
                    has_backward=has_backward,
                    varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad,
                    cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                    seqlen_q=seqlen_q, seqlen=seqlen,
                    page_size=page_size, k_paged=k_paged, v_paged=v_paged, page_table=page_table,
                    dropout_p=dropout_p, window_size=window_size, window_size_fa=window_size_fa,
                    softcap=softcap, deterministic=deterministic,
                    num_splits=num_splits, pack_gqa=pack_gqa, sinks=sinks,
                )

                for display_name, cli_name, setup_fn in active_backends:
                    fwd_fn, bwd_fn = setup_fn(ctx)

                    if fwd_fn is not None and has_forward:
                        time.sleep(1.0)
                        print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}")
                        # Scaled by 1e-3 here; the table multiplies by 1e3 and prints it as ms.
                        ms = do_bench(fwd_fn, warmup=warmup, rep=rep) * 1e-3
                        time_f[cfg, display_name] = ms

                    if bwd_fn is not None and has_backward:
                        time.sleep(1.0)
                        print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}")
                        ms = do_bench(bwd_fn, warmup=warmup, rep=rep) * 1e-3
                        time_b[cfg, display_name] = ms

    # ── Print results table ──────────────────────────────────────────────────
    backend_names = [name for name, _, _ in BACKENDS]
    # Only show backends that produced at least one timing.
    shown_backends = [b for b in backend_names if any(b == k[1] for k in list(time_f) + list(time_b))]
    if not shown_backends:
        return
    col_w = 16

    # The backward table scales the forward FLOP count by 2.5.
    for direction, times, flops_mult in [("FWD", time_f, 1.0), ("BWD", time_b, 2.5)]:
        if not times:
            continue
        configs = sorted(set(k[0] for k in times))
        if not configs:
            continue

        header = f"{'hdim':>9} {'causal':>6} {'batch':>5} {'seqlen':>6}"
        for b in shown_backends:
            header += f" {b:>{col_w}}"
        print(f"\n{'=' * len(header)}")
        print(f" {direction} (ms / TFLOPS)")
        print(f"{'=' * len(header)}")
        print(header)
        print("-" * len(header))

        for cfg in configs:
            headdim, headdim_v, causal, seqlen, batch_size, nheads = cfg
            nFLOPS = flops(batch_size, nheads, seqlen, seqlen, headdim, headdim_v, causal=causal)
            hdim_str = str(headdim) if headdim == headdim_v else f"{headdim}-{headdim_v}"
            row = f"{hdim_str:>9} {str(causal):>6} {batch_size:>5} {seqlen:>6}"
            for b in shown_backends:
                t = times.get((cfg, b))
                if t is not None:
                    tflops = flops_mult * nFLOPS / t * 1e-12
                    ms = t * 1e3
                    cell = f"{ms:.2f}/{tflops:.0f}"
                    row += f" {cell:>{col_w}}"
                else:
                    row += f" {'—':>{col_w}}"
            print(row)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
================================================
FILE: benchmarks/benchmark_causal.py
================================================
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
# # from flash_attn.triton.fused_attention import attention as attention
# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
# from flash_attn.flash_attn_triton_og import attention as attention_og
# from triton.ops.flash_attention import attention as attention_triton
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """Unfused scaled dot-product self-attention built from plain PyTorch ops.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float, dropout probability applied to the attention weights
        causal: bool, add an upper-triangular additive mask before the softmax
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    bsz, seq, _, n_heads, head_dim = qkv.shape
    query, key, value = qkv.unbind(dim=2)
    # Fold batch and heads into one leading dim so a single batched matmul covers all heads.
    query = rearrange(query, 'b t h d -> (b h) t d')
    key = rearrange(key, 'b s h d -> (b h) d s')
    scale = 1.0 / math.sqrt(head_dim)
    # baddbmm requires an input tensor; with beta=0 its (uninitialized) contents are ignored.
    buf = torch.empty(bsz * n_heads, seq, seq, dtype=qkv.dtype, device=qkv.device)
    logits = torch.baddbmm(buf, query, key, beta=0, alpha=scale)
    logits = rearrange(logits, '(b h) t s -> b h t s', h=n_heads)
    if causal:
        # "triu_tril_cuda_template" is not implemented for 'BFloat16', so the mask is
        # created in the default float dtype and cast to the score dtype afterwards.
        fill = torch.full((seq, seq), -10000.0, device=logits.device)
        upper = torch.triu(fill, 1)
        # Adding the mask was measured to be faster than masked_fill_.
        logits = logits + upper.to(dtype=logits.dtype)
    probs = torch.softmax(logits, dim=-1)
    probs_after_dropout = F.dropout(probs, dropout_p)
    result = torch.einsum('bhts,bshd->bthd', probs_after_dropout, value)
    return result.to(dtype=qkv.dtype)
# ---------------------------------------------------------------------------
# Quick single-configuration benchmark: time and profile the forward pass of
# flash_attn_qkvpacked_func for one (batch, seqlen, nheads, headdim) setting,
# print an ideal A100 time estimate, then stop the script.
# ---------------------------------------------------------------------------
torch.manual_seed(0)
repeats = 30
batch_size = 8
seqlen = 2048
nheads = 12
headdim = 128
# nheads = 24
# headdim = 64
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
dropout_p = 0.0
causal = True
dtype = torch.float16
device = 'cuda'

qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                  requires_grad=True)
# Cumulative sequence offsets [0, seqlen, 2*seqlen, ...] and the flattened
# (batch*seqlen, ...) view of qkv; only the commented-out varlen calls below use them.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                          device=qkv.device)

qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)

# for dropout_p in [0.1, 0.0]:
#     for causal in [False, True]:
#         print(f"### {dropout_p = }, {causal = } ###")
#         pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)


# nheads_k = 2
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
#                  requires_grad=True)
# if fav2_kvpacked_func is not None:
#     benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
#     pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)

# dropout_p = 0.0
# causal = False
# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
#               repeats=repeats, desc='PyTorch Attention')

# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)

# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
#                        requires_grad=True) for _ in range(3)]
# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
# # pytorch_profiler(attention, q, k, v, 1.0, backward=True)

# from src.ops.fftconv import fftconv_func

# dim = nheads * headdim
# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
# k = torch.randn(dim, seqlen, device=device, requires_grad=True)
# D = torch.randn(dim, device=device, requires_grad=True)
# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
# pytorch_profiler(fftconv_func, u, k, D, backward=True)
# pytorch_profiler(torch.fft.rfft, u.float())

# FLOP count 4 * batch * seqlen^2 * nheads * headdim, converted to milliseconds
# at 312 TFLOPS: flops / 312e12 s == flops / 312 / 1e9 ms.
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
# Stop here on purpose: everything after this point in the file does not run.
# SystemExit is raised directly because the `exit` name is injected by the
# `site` module for interactive use and is not guaranteed to exist in programs.
raise SystemExit(0)
def time_fwd_bwd(func, *args, **kwargs):
    """Return the mean forward and mean backward time of `func`.

    All positional and keyword arguments are forwarded unchanged to
    `benchmark_fwd_bwd`. Element [1] of each of its two results is read for
    its `.mean` attribute (presumably a timing measurement object; confirm
    against `flash_attn.utils.benchmark`).
    """
    fwd_result, bwd_result = benchmark_fwd_bwd(func, *args, **kwargs)
    fwd_mean = fwd_result[1].mean
    bwd_mean = bwd_result[1].mean
    return fwd_mean, bwd_mean
# ---------------------------------------------------------------------------
# Sweep over (causal, headdim, batch_size, seqlen) and record mean forward /
# backward times per implementation in time_f / time_b, keyed by
# ((causal, headdim, batch_size, seqlen), method_name). The two dicts are
# pickled to disk at the end.
# ---------------------------------------------------------------------------
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

time_f = {}
time_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            # Keep nheads * headdim fixed at `dim`.
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=qkv.device)
            qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
                causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b

            qkv = qkv.detach().requires_grad_(True)
            # Use the function this file actually imports; the previous name
            # `fav2_qkvpacked_func` is not defined anywhere in the file.
            f, b = time_fwd_bwd(
                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b

            # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # # Try both values of sequence_parallel and pick the faster one
            # f, b = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     False, repeats=repeats, verbose=False
            # )
            # _, b0 = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     True, repeats=repeats, verbose=False
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)

            # The PyTorch reference is only timed up to seqlen 8192; longer
            # sequences are recorded as NaN instead.
            if seqlen <= 8 * 1024:
                qkv = qkv.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
                )
            else:
                f, b = float('nan'), float('nan')
            time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b

            # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # import xformers.ops as xops
            # f, b = time_fwd_bwd(
            #     xops.memory_efficient_attention, q, k, v,
            #     attn_bias=xops.LowerTriangularMask() if causal else None,
            #     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b


import pickle
with open('flash2_attn_time_h100.plk', 'wb') as fp:
    pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
================================================
FILE: benchmarks/benchmark_flash_attention.py
================================================
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
from flash_attn import flash_attn_qkvpacked_func
try:
from triton.ops.flash_attention import attention as attention_triton
except ImportError:
attention_triton = None
try:
import xformers.ops as xops
except ImportError:
xops = None
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Estimate the FLOP count of attention for one benchmark configuration.

    The forward count is 4 * batch * seqlen**2 * nheads * headdim, floor-halved
    when `causal` is true. `mode` selects the pass: "fwd" returns that count
    unchanged, "bwd" returns 2.5x it and "fwd_bwd" returns 3.5x it.
    Any other `mode` fails the assertion.
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    causal_divisor = 2 if causal else 1
    fwd_flops = 4 * batch * seqlen**2 * nheads * headdim // causal_divisor
    if mode == "fwd":
        return fwd_flops
    if mode == "bwd":
        return 2.5 * fwd_flops
    return 3.5 * fwd_flops
def efficiency(flop, time):
    """Convert a FLOP count and a duration into units of 1e12 FLOPs per time unit.

    Returns 0.0 when `time` is NaN, the value the benchmark loop stores for a
    method that could not be timed.
    """
    if math.isnan(time):
        return 0.0
    return flop / time / 10**12
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """Baseline self-attention implemented with standard PyTorch operations.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float, dropout probability applied to the softmax output
        causal: bool, whether to add an upper-triangular additive mask
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    n_batch, n_tokens, _, n_heads, head_dim = qkv.shape
    q_proj, k_proj, v_proj = qkv.unbind(dim=2)
    # Merge batch and head dims; K is laid out transposed so Q @ K gives (t, s) scores.
    q_flat = rearrange(q_proj, 'b t h d -> (b h) t d')
    k_flat_t = rearrange(k_proj, 'b s h d -> (b h) d s')
    inv_sqrt_d = 1.0 / math.sqrt(head_dim)
    # Preallocated input for `baddbmm`; beta=0 means its contents never reach the result.
    workspace = torch.empty(n_batch * n_heads, n_tokens, n_tokens,
                            dtype=qkv.dtype, device=qkv.device)
    raw_scores = torch.baddbmm(workspace, q_flat, k_flat_t, beta=0, alpha=inv_sqrt_d)
    scores = rearrange(raw_scores, '(b h) t s -> b h t s', h=n_heads)
    if causal:
        # "triu_tril_cuda_template" is not implemented for 'BFloat16',
        # so the mask is built in float and cast afterwards.
        mask = torch.triu(
            torch.full((n_tokens, n_tokens), -10000.0, device=scores.device), 1
        )
        # Adding is faster than masked_fill_
        scores = scores + mask.to(dtype=scores.dtype)
    weights = F.dropout(torch.softmax(scores, dim=-1), dropout_p)
    out = torch.einsum('bhts,bshd->bthd', weights, v_proj)
    return out.to(dtype=qkv.dtype)
def time_fwd_bwd(func, *args, **kwargs):
    """Benchmark `func` with `benchmark_fwd_bwd` and return (forward mean, backward mean).

    Every argument is passed straight through. The `.mean` attribute of
    element [1] of each returned result is what gets reported (presumably a
    timing measurement; confirm against `flash_attn.utils.benchmark`).
    """
    results = benchmark_fwd_bwd(func, *args, **kwargs)
    forward, backward = results
    return forward[1].mean, backward[1].mean
# ---------------------------------------------------------------------------
# Benchmark sweep: for every (causal, headdim, batch_size, seqlen) config,
# time forward and backward passes of each available attention backend and
# print the achieved throughput computed by flops() / efficiency().
# ---------------------------------------------------------------------------
repeats = 30
device = 'cuda'
dtype = torch.float16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

# Optional backends are listed only when their import succeeded at the top of the file.
methods = (["Flash2", "Pytorch"]
           + (["Triton"] if attention_triton is not None else [])
           + (["xformers.c"] if xops is not None else [])
           + (["xformers.f"] if xops is not None else []))

# All result dicts are keyed by (config, method_name).
time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            # Keep nheads * headdim fixed at `dim`.
            nheads = dim // headdim

            # FlashAttention 2
            if "Flash2" in methods:
                qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                                  device=device, dtype=dtype, requires_grad=True)
                f, b = time_fwd_bwd(
                    flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal,
                    repeats=repeats, verbose=False
                )
                time_f[config, "Flash2"] = f
                time_b[config, "Flash2"] = b

            # PyTorch baseline
            if "Pytorch" in methods:
                try:
                    # fresh tensor avoids grad-history reuse issues
                    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                                      device=device, dtype=dtype, requires_grad=True)
                    f, b = time_fwd_bwd(
                        attention_pytorch, qkv, dropout_p, causal=causal,
                        repeats=repeats, verbose=False
                    )
                except Exception:
                    # Any failure of the reference implementation is recorded as NaN
                    # so the sweep can continue.
                    f, b = float('nan'), float('nan')
                time_f[config, "Pytorch"] = f
                time_b[config, "Pytorch"] = b

            # Triton
            if "Triton" in methods and attention_triton is not None:
                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim,
                                       device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
                # Try both values of sequence_parallel and pick the faster backward
                try:
                    f, b = time_fwd_bwd(
                        attention_triton, q, k, v, causal, headdim**(-0.5),
                        False, repeats=repeats, verbose=False
                    )
                except Exception:
                    f, b = float('nan'), float('inf')
                try:
                    _, b0 = time_fwd_bwd(
                        attention_triton, q, k, v, causal, headdim**(-0.5),
                        True, repeats=repeats, verbose=False
                    )
                except Exception:
                    b0 = float('inf')
                time_f[config, "Triton"] = f
                # inf marks a failed run; if both variants failed, store NaN instead.
                time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')

            # xFormers CUTLASS
            if "xformers.c" in methods and xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,
                                       device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
                # Pass repeats/verbose explicitly so xFormers is timed with the same
                # settings as the other backends in this script.
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp),
                    repeats=repeats, verbose=False
                )
                time_f[config, "xformers.c"] = f
                time_b[config, "xformers.c"] = b

            # xFormers Flash
            if "xformers.f" in methods and xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,
                                       device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                    repeats=repeats, verbose=False
                )
                time_f[config, "xformers.f"] = f
                time_b[config, "xformers.f"] = b

            # Report
            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                # Skip methods that produced no timing for this config.
                if (config, method) not in time_f or (config, method) not in time_b:
                    continue
                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                speed_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                    time_b[config, method]
                )
                speed_f_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                    time_f_b[config, method]
                )
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                )


# with open('flash2_attn_time.plk', 'wb') as fp:
#     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
================================================
FILE: benchmarks/benchmark_gemm.py
================================================
import time
import torch
import torch.utils.benchmark as benchmark
from triton.testing import do_bench
# Label used in the benchmark output for the GEMM backend behind torch.matmul,
# chosen from which toolkit version this torch build reports.
if torch.version.cuda:
    backendBLAS = "cuBLAS"
elif torch.version.hip:
    backendBLAS = "hipBLAS"
else:
    # Neither a CUDA nor a HIP version is reported: keep the name bound so
    # later uses of the label do not fail with NameError.
    backendBLAS = "BLAS"
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Time `fn(*inputs, **kwinputs)` with `torch.utils.benchmark.Timer`.

    Arguments:
        fn: callable to benchmark
        *inputs / **kwinputs: arguments passed to `fn` on every call
        repeats: value passed to `Timer.timeit`
        desc: label printed before the run when `verbose` is true
        verbose: when true, print the label and the resulting measurement
    Returns:
        (timer, measurement): the Timer object and the result of `timeit`
    """
    if verbose:
        print(desc, '- Forward pass')
    timer_globals = {'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs}
    timer = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
# ---------------------------------------------------------------------------
# GEMM throughput sweep: multiply an (m, k) matrix by a (k, n) matrix for
# k = 512, 1024, ..., 8192 and report TFLOPS measured two ways
# (torch.utils.benchmark and triton's do_bench).
# ---------------------------------------------------------------------------
torch.manual_seed(0)
repeats = 30
dtype = torch.bfloat16
device = 'cuda'
verbose = False
m, n = 8192, 8192

tflops_matmul = {}   # k -> TFLOPS from torch.utils.benchmark
tflops_matmul1 = {}  # k -> TFLOPS from triton do_bench
for k in range(512, 8192 + 1, 512):
    a = torch.randn(m, k, device=device, dtype=dtype)
    # b is created as (n, k) and then viewed as (k, n) via transpose.
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS_matmul = 2 * m * n * k

    time.sleep(2)  # to reduce power throttling
    _, timing = benchmark_forward(torch.matmul, a, b, desc=backendBLAS, verbose=verbose, repeats=repeats)
    # timing.mean is used as seconds here (it is also printed as ms via * 1e3).
    tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12
    print(f'[torch.utils.benchmark] {backendBLAS}, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')

    time.sleep(2)  # to reduce power throttling
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
    # do_bench's result is treated as milliseconds, hence the 1e-9 factor.
    tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9
    print(f'[triton.test.do_bench] {backendBLAS}, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')
================================================
FILE: csrc/flash_attn/flash_api.cpp
================================================
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
#include "philox_unpack.cuh" // For at::cuda::philox::unpack
#include <cutlass/numeric_types.h>
#include "namespace_config.h"
#include "hardware_info.h"
#include "flash.h"
#include "static_switch.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
namespace FLASH_NAMESPACE {
void set_params_fprop(Flash_fwd_params ¶ms,
// sizes
const size_t b,
const size_t seqlen_q,
const size_t seqlen_k,
const size_t seqlen_q_rounded,
const size_t seqlen_k_rounded,
const size_t h,
const size_t h_k,
const size_t d,
const size_t d_rounded,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
at::Tensor out,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *seqused_k,
void *p_d,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right,
const float softcap,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false) {
// Reset the parameters
params = {};
params.is_bf16 = q.dtype() == torch::kBFloat16;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
// All stride are in elements, not bytes.
params.q_row_stride = q.stride(-3);
params.k_row_stride = k.stride(-3);
params.v_row_stride = v.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = k.stride(-2);
params.v_head_stride = v.stride(-2);
params.o_ptr = out.data_ptr();
params.o_row_stride = out.stride(-3);
params.o_head_stride = out.stride(-2);
if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = q.stride(0);
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
params.o_batch_stride = out.stride(0);
if (seqlenq_ngroups_swapped) {
params.q_batch_stride *= seqlen_q;
params.o_batch_stride *= seqlen_q;
}
}
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
params.seqused_k = static_cast<int *>(seqused_k);
// P = softmax(QK^T)
params.p_ptr = p_d;
// Softmax sum
params.softmax_lse_ptr = softmax_lse_d;
// Set the dimensions.
params.b = b;
params.h = h;
params.h_k = h_k;
params.h_h_k_ratio = h / h_k;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.seqlen_q_rounded = seqlen_q_rounded;
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
// Set the different scale values.
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
#endif
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else{
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead of <
// params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
// params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
params.is_causal = window_size_left < 0 && window_size_right == 0;
if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;
#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
"This flash attention build does not support local attention.");
#endif
params.is_seqlens_k_cumulative = true;
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif
params.unpadded_lse = unpadded_lse;
params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
}
void set_params_dgrad(Flash_bwd_params ¶ms,
// sizes
const size_t b,
const size_t seqlen_q,
const size_t seqlen_k,
const size_t seqlen_q_rounded,
const size_t seqlen_k_rounded,
const size_t h,
const size_t h_k,
const size_t d,
const size_t d_rounded,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor out,
const at::Tensor dout,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *dq_accum_d,
void *dk_accum_d,
void *dv_accum_d,
void *softmax_lse_d,
void *dsoftmax_sum_d,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right,
const float softcap,
bool deterministic,
const bool unpadded_lse) {
set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
q, k, v, out,
cu_seqlens_q_d,
cu_seqlens_k_d,
nullptr,
nullptr,
softmax_lse_d,
p_dropout,
softmax_scale,
window_size_left,
window_size_right,
softcap,
false, // seqlenq_ngroups_swapped
unpadded_lse);
// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
params.do_row_stride = dout.stride(-3);
params.do_head_stride = dout.stride(-2);
params.dq_ptr = dq.data_ptr();
params.dk_ptr = dk.data_ptr();
params.dv_ptr = dv.data_ptr();
params.dq_row_stride = dq.stride(-3);
params.dk_row_stride = dk.stride(-3);
params.dv_row_stride = dv.stride(-3);
params.dq_head_stride = dq.stride(-2);
params.dk_head_stride = dk.stride(-2);
params.dv_head_stride = dv.stride(-2);
if (cu_seqlens_q_d == nullptr) {
params.do_batch_stride = dout.stride(0);
params.dq_batch_stride = dq.stride(0);
params.dk_batch_stride = dk.stride(0);
params.dv_batch_stride = dv.stride(0);
}
params.dq_accum_ptr = dq_accum_d;
params.dk_accum_ptr = dk_accum_d;
params.dv_accum_ptr = dv_accum_d;
// Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d;
params.deterministic = deterministic;
}
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
}
});
});
});
}
// Choose how many splits to use so that the SMs are well occupied.
// Example: with batch * n_heads = 48 and 108 SMs, 2 splits (efficiency = 0.89) beat
// 3 splits (efficiency = 0.67). More splits also mean more HBM reads/writes, so instead
// of taking the most efficient count we take the smallest eligible count whose
// efficiency reaches 85% of the best efficiency found.
//
// batch_nheads_mblocks: number of work items before splitting
// num_SMs:              number of SMs to fill
// num_n_blocks:         number of blocks that can be distributed across splits
// max_splits:           upper bound on the returned value
// Returns 1 when no split count qualifies.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }

    max_splits = std::min({max_splits, num_SMs, num_n_blocks});

    auto blocks_per_split = [num_n_blocks](int splits) { return (num_n_blocks + splits - 1) / splits; };
    // Some split counts are redundant. With 64 blocks, 11 splits gives 6 blocks per split
    // (6 * 10 + 4), and 12 splits would still need 6 per split (6 * 11 + (-2)), i.e. it is
    // 11 splits anyway. A count is therefore eligible only if it changes the number of
    // blocks per split compared with the previous count.
    auto eligible = [&blocks_per_split](int splits) {
        if (splits == 1) { return true; }
        return blocks_per_split(splits) != blocks_per_split(splits - 1);
    };

    // First pass: efficiency of every candidate (0 for ineligible ones) and the best value.
    std::vector<float> eff_per_split;
    eff_per_split.reserve(max_splits);
    float best_eff = 0.f;
    for (int splits = 1; splits <= max_splits; splits++) {
        if (eligible(splits)) {
            // Efficiency = fraction of the last wave's SM slots that are actually used.
            float waves = float(batch_nheads_mblocks * splits) / num_SMs;
            float eff = waves / ceil(waves);
            if (eff > best_eff) { best_eff = eff; }
            eff_per_split.push_back(eff);
        } else {
            eff_per_split.push_back(0.f);
        }
    }

    // Second pass: smallest eligible count that is within 85% of the best.
    for (int splits = 1; splits <= max_splits; splits++) {
        if (eligible(splits) && eff_per_split[splits - 1] >= 0.85 * best_eff) {
            return splits;
        }
    }
    return 1;
}
std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
const int head_size_rounded, const float p_dropout,
const int num_splits, const int num_sm, struct c10::TensorOptions opts) {
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
params.num_splits = num_splits;
at::Tensor softmax_lse_accum;
at::Tensor out_accum;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
}
if (params.num_splits > 1) {
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}
return std::make_tuple(softmax_lse_accum, out_accum);
}
void set_params_alibi(Flash_fwd_params ¶ms, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
#ifdef FLASHATTENTION_DISABLE_ALIBI
TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
params.alibi_slopes_ptr = nullptr;
#else
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else {
params.alibi_slopes_ptr = nullptr;
}
#endif
}
// Forward attention for fixed-length (batched) inputs.
//
// Validates dtypes, devices, strides and shapes of q/k/v (and out_, if given), prepares
// the kernel parameters, launches the forward kernel on the current CUDA stream and
// returns {out, softmax_lse, p, rng_state}.
//   - `p` is only a full-size tensor when return_softmax is set (which requires p_dropout > 0);
//     otherwise it is an empty tensor.
//   - When seqlen_k == 0 no kernel is launched: out is zeroed and softmax_lse is filled with +inf.
//   - When seqlen_q == 1 with grouped heads and no window/dropout/alibi, q is temporarily viewed as
//     (batch, ngroups, num_heads_k, head_size) and the results are reshaped back before returning.
// Note that `q` is taken by non-const reference and is reassigned in that swapped case.
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        std::optional<at::Generator> gen_) {

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
    // Guard the remainder/division by num_heads_k below: integer division by zero is undefined.
    TORCH_CHECK(num_heads_k > 0, "Number of heads in key/value must be positive");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    // A window that covers the whole key sequence is the same as no window on that side.
    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
    if (is_causal) { window_size_right = 0; }

    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza
    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
    const int ngroups = num_heads / num_heads_k;
    if (seqlenq_ngroups_swapped) {
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
        seqlen_q = ngroups;
        num_heads = num_heads_k;
    }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        // The user-provided output is checked against q's shape as it was on entry.
        // NOTE(review): `sizes` was obtained from q before q may have been reassigned above;
        // confirm the view is still valid in the swapped case.
        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
        if (seqlenq_ngroups_swapped) {
            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
        }
    } else {
        out = torch::empty_like(q);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    }
    else {
        p = torch::empty({ 0 }, opts);
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*seqused_k=*/nullptr,
                     return_softmax ? p.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap
                     );

    // Keep references to these tensors to extend their lifetime
    at::Tensor softmax_lse_accum, out_accum;
    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
        head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    // Forward kernel will populate memory with the seed and offset.
    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

    if (p_dropout > 0.0)  {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream);
    } else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    // Undo the (seqlen_q <-> ngroups) swap so callers see the original layout again.
    if (seqlenq_ngroups_swapped) {
        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
        q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
    }
    return {out, softmax_lse, p, rng_state};
}
// Forward pass for variable-length (packed) sequences.
//
// Validates the inputs, allocates the output / log-sum-exp / optional softmax
// buffers, fills a Flash_fwd_params struct and launches the forward kernel via
// run_mha_fwd (skipped when max_seqlen_k == 0, in which case `out` is zeroed
// and `softmax_lse` is filled with +inf).
//
// Notes on arguments:
//   q             Taken by non-const reference: when the seqlen_q/ngroups swap
//                 below is taken, `q` is reassigned to a reshaped tensor and
//                 reshaped back before returning.
//   out_          Optional preallocated output; must match q's original shape.
//   block_table_  When present, k/v are interpreted as a paged KV cache of
//                 shape num_blocks x page_block_size x num_heads_k x head_size.
//   zero_tensors  When true, `out` (and `p` if returned) are zeroed and
//                 `softmax_lse` is filled with -inf before the launch.
//   return_softmax Only allowed when p_dropout > 0.
//
// Returns {out, softmax_lse, p, rng_state}:
//   softmax_lse   float tensor allocated as {num_heads, total_q}; reshaped to
//                 {num_heads * max_seqlen_q, batch_size} in the swapped path.
//   p             {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}
//                 when return_softmax is set, otherwise an empty {0} tensor.
//   rng_state     int64 tensor with 2 elements handed to the kernel through
//                 params.rng_state.
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               std::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
               std::optional<const at::Tensor> &leftpad_k_,  // batch_size
               std::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               std::optional<at::Generator> gen_) {

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    at::Tensor block_table;
    const bool paged_KV = block_table_.has_value();
    if (paged_KV) {
        block_table = block_table_.value();
        CHECK_DEVICE(block_table);
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    // Take an owning snapshot of q's original shape. q.sizes() is a non-owning
    // view into q's metadata, and `q` may be reassigned below (swap path) while
    // sizes[0] / sizes[1] are still needed to validate `out_`.
    const auto sizes = q.sizes().vec();

    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size = sizes[2];
    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
    const int num_blocks = !paged_KV ? 0 : k.size(0);
    const int page_block_size = !paged_KV ? 1 : k.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
    if (is_causal) { window_size_right = 0; }

    // Validate the head/batch geometry BEFORE computing ngroups and before the
    // swap below. After the swap num_heads == num_heads_k, so a divisibility
    // check placed afterwards can never fail and a bad head count would surface
    // as an unrelated reshape error (or a division by zero for num_heads_k == 0).
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
    TORCH_CHECK(num_heads_k > 0, "Number of heads in key/value must be positive");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza
    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
    const int ngroups = num_heads / num_heads_k;
    if (seqlenq_ngroups_swapped) {
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
        max_seqlen_q = ngroups;
        num_heads = num_heads_k;
        // In the swapped layout every batch element has exactly `ngroups` query rows,
        // so no cumulative-length array is passed for q.
        cu_seqlens_q_d = nullptr;
    }

    const int total_q = q.sizes()[0];

    // A window at least as long as the key sequence is equivalent to no window.
    // NOTE: intentionally kept after the swap decision above so the chosen
    // kernel path for valid inputs is unchanged.
    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, total_q, num_heads, head_size);
    if (!paged_KV) {
        const int total_k = k.size(0);
        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
    } else {
        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    }

    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    if (seqused_k.has_value()){
        auto seqused_k_ = seqused_k.value();
        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k_, batch_size);
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        // Compared against q's ORIGINAL (pre-swap) shape.
        CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
        if (seqlenq_ngroups_swapped) {
            // NOTE(review): the reshape after transpose(1, 2) may have to materialize a copy,
            // in which case the kernel writes into that copy rather than the caller's buffer
            // (the returned tensor is still correct) -- confirm whether callers rely on out_
            // being filled in place.
            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
        }
    } else {
        out = torch::empty_like(q);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    auto opts = q.options();
    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    }
    else {
        p = torch::empty({ 0 }, opts);
    }

    if (zero_tensors) {
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_softmax) {p.zero_();}
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k.data_ptr(),
                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                     return_softmax ? p.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     seqlenq_ngroups_swapped,
                     /*unpadded_lse*/true);
    params.total_q = total_q;

    if (paged_KV) {
        params.block_table = block_table.data_ptr<int>();
        params.block_table_batch_stride = block_table.stride(0);
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
    }
    params.page_block_size = page_block_size;
    // Keep references to these tensors to extend their lifetime
    at::Tensor softmax_lse_accum, out_accum;
    if (seqlenq_ngroups_swapped) {
        // Only apply split-k for decoding
        std::tie(softmax_lse_accum, out_accum) =
            set_params_splitkv(params, batch_size, num_heads, head_size,
                               max_seqlen_k, max_seqlen_q, head_size_rounded,
                               p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
    }

    if (leftpad_k_.has_value()) {
        auto leftpad_k = leftpad_k_.value();
        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
        CHECK_DEVICE(leftpad_k);
        CHECK_CONTIGUOUS(leftpad_k);
        CHECK_SHAPE(leftpad_k, batch_size);
        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
    }

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    // Forward kernel will populate memory with the seed and offset.
    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

    if (p_dropout > 0.0)  {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream, paged_KV);
    } else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    if (seqlenq_ngroups_swapped) {
        // Undo the swap: here max_seqlen_q holds ngroups, so the tensors go back to
        // (batch_size, num_heads_k * ngroups, head_size), i.e. q's original layout.
        int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
        int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
        out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
        q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
    }

    return {out, softmax_lse, p, rng_state};
}
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
});
});
});
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
std::optional<at::Generator> gen_,
std::optional<at::Tensor> &rng_state) {
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; }
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x_min = cc_major >= 8;
TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
} else {
dq = torch::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
} else {
dk = torch::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
} else {
dv = torch::empty_like(v);
}
// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;
auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
if (!deterministic) {
dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
}
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
Flash_bwd_params params;
set_params_dgrad(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, k, v, out,
dout, dq, dk_expanded, dv_expanded,
nullptr,
nullptr,
loop ? dq_accum.data_ptr() : nullptr,
// loop ? dk_accum.data_ptr() : nullptr,
// loop ? dv_accum.data_ptr() : nullptr,
nullptr,
nullptr,
softmax_lse.data_ptr(),
softmax_d.data_ptr(),
p_dropout,
gitextract_wgwkfssl/ ├── .github/ │ └── workflows/ │ ├── README.md │ ├── _build.yml │ ├── build.yml │ ├── pre-commit.yaml │ ├── publish-fa4.yml │ └── publish.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── AI/ │ ├── DEBUG_2CTA.md │ ├── RACECHECK_TMA_HAZARD.md │ ├── SM90_BLOCK_SIZE_TUNING.md │ ├── SM90_R2P_MASKING_SASS.md │ ├── VARLEN_PREPROCESS_TILE_BUG.md │ ├── racecheck_repro_1d_bulk.py │ └── racecheck_repro_1d_tensor.py ├── AUTHORS ├── CLAUDE.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── benchmarks/ │ ├── bench_sm90.py │ ├── benchmark_alibi.py │ ├── benchmark_attn.py │ ├── benchmark_causal.py │ ├── benchmark_flash_attention.py │ └── benchmark_gemm.py ├── csrc/ │ ├── flash_attn/ │ │ ├── flash_api.cpp │ │ └── src/ │ │ ├── alibi.h │ │ ├── block_info.h │ │ ├── dropout.h │ │ ├── flash.h │ │ ├── flash_bwd_hdim128_bf16_causal_sm80.cu │ │ ├── flash_bwd_hdim128_bf16_sm80.cu │ │ ├── flash_bwd_hdim128_fp16_causal_sm80.cu │ │ ├── flash_bwd_hdim128_fp16_sm80.cu │ │ ├── flash_bwd_hdim192_bf16_causal_sm80.cu │ │ ├── flash_bwd_hdim192_bf16_sm80.cu │ │ ├── flash_bwd_hdim192_fp16_causal_sm80.cu │ │ ├── flash_bwd_hdim192_fp16_sm80.cu │ │ ├── flash_bwd_hdim256_bf16_causal_sm80.cu │ │ ├── flash_bwd_hdim256_bf16_sm80.cu │ │ ├── flash_bwd_hdim256_fp16_causal_sm80.cu │ │ ├── flash_bwd_hdim256_fp16_sm80.cu │ │ ├── flash_bwd_hdim32_bf16_causal_sm80.cu │ │ ├── flash_bwd_hdim32_bf16_sm80.cu │ │ ├── flash_bwd_hdim32_fp16_causal_sm80.cu │ │ ├── flash_bwd_hdim32_fp16_sm80.cu │ │ ├── flash_bwd_hdim64_bf16_causal_sm80.cu │ │ ├── flash_bwd_hdim64_bf16_sm80.cu │ │ ├── flash_bwd_hdim64_fp16_causal_sm80.cu │ │ ├── flash_bwd_hdim64_fp16_sm80.cu │ │ ├── flash_bwd_hdim96_bf16_causal_sm80.cu │ │ ├── flash_bwd_hdim96_bf16_sm80.cu │ │ ├── flash_bwd_hdim96_fp16_causal_sm80.cu │ │ ├── flash_bwd_hdim96_fp16_sm80.cu │ │ ├── flash_bwd_kernel.h │ │ ├── flash_bwd_launch_template.h │ │ ├── flash_bwd_preprocess_kernel.h │ │ ├── flash_fwd_hdim128_bf16_causal_sm80.cu │ │ ├── 
flash_fwd_hdim128_bf16_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_causal_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_causal_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_causal_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_causal_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_causal_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_sm80.cu │ │ ├── flash_fwd_hdim32_bf16_causal_sm80.cu │ │ ├── flash_fwd_hdim32_bf16_sm80.cu │ │ ├── flash_fwd_hdim32_fp16_causal_sm80.cu │ │ ├── flash_fwd_hdim32_fp16_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_causal_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_causal_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_causal_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_causal_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_sm80.cu │ │ ├── flash_fwd_kernel.h │ │ ├── flash_fwd_launch_template.h │ │ ├── flash_fwd_split_hdim128_bf16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim128_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim128_fp16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim128_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim192_bf16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim192_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim192_fp16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim192_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim256_bf16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim256_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim256_fp16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim256_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim32_bf16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim32_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim32_fp16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim32_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim64_bf16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim64_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim64_fp16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim64_fp16_sm80.cu │ │ ├── 
flash_fwd_split_hdim96_bf16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim96_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim96_fp16_causal_sm80.cu │ │ ├── flash_fwd_split_hdim96_fp16_sm80.cu │ │ ├── generate_kernels.py │ │ ├── hardware_info.h │ │ ├── kernel_traits.h │ │ ├── mask.h │ │ ├── namespace_config.h │ │ ├── philox.cuh │ │ ├── philox_unpack.cuh │ │ ├── rotary.h │ │ ├── softmax.h │ │ ├── static_switch.h │ │ └── utils.h │ ├── flash_attn_ck/ │ │ ├── flash_api.cpp │ │ ├── flash_common.cpp │ │ ├── flash_common.hpp │ │ ├── mha_bwd.cpp │ │ ├── mha_fwd.cpp │ │ ├── mha_fwd_kvcache.cpp │ │ ├── mha_varlen_bwd.cpp │ │ └── mha_varlen_fwd.cpp │ ├── fused_dense_lib/ │ │ ├── README.md │ │ ├── fused_dense.cpp │ │ ├── fused_dense_cuda.cu │ │ └── setup.py │ └── layer_norm/ │ ├── README.md │ ├── ln.h │ ├── ln_api.cpp │ ├── ln_bwd_1024.cu │ ├── ln_bwd_1280.cu │ ├── ln_bwd_1536.cu │ ├── ln_bwd_2048.cu │ ├── ln_bwd_256.cu │ ├── ln_bwd_2560.cu │ ├── ln_bwd_3072.cu │ ├── ln_bwd_4096.cu │ ├── ln_bwd_512.cu │ ├── ln_bwd_5120.cu │ ├── ln_bwd_6144.cu │ ├── ln_bwd_7168.cu │ ├── ln_bwd_768.cu │ ├── ln_bwd_8192.cu │ ├── ln_bwd_kernels.cuh │ ├── ln_fwd_1024.cu │ ├── ln_fwd_1280.cu │ ├── ln_fwd_1536.cu │ ├── ln_fwd_2048.cu │ ├── ln_fwd_256.cu │ ├── ln_fwd_2560.cu │ ├── ln_fwd_3072.cu │ ├── ln_fwd_4096.cu │ ├── ln_fwd_512.cu │ ├── ln_fwd_5120.cu │ ├── ln_fwd_6144.cu │ ├── ln_fwd_7168.cu │ ├── ln_fwd_768.cu │ ├── ln_fwd_8192.cu │ ├── ln_fwd_kernels.cuh │ ├── ln_kernel_traits.h │ ├── ln_parallel_bwd_1024.cu │ ├── ln_parallel_bwd_1280.cu │ ├── ln_parallel_bwd_1536.cu │ ├── ln_parallel_bwd_2048.cu │ ├── ln_parallel_bwd_256.cu │ ├── ln_parallel_bwd_2560.cu │ ├── ln_parallel_bwd_3072.cu │ ├── ln_parallel_bwd_4096.cu │ ├── ln_parallel_bwd_512.cu │ ├── ln_parallel_bwd_5120.cu │ ├── ln_parallel_bwd_6144.cu │ ├── ln_parallel_bwd_7168.cu │ ├── ln_parallel_bwd_768.cu │ ├── ln_parallel_bwd_8192.cu │ ├── ln_parallel_fwd_1024.cu │ ├── ln_parallel_fwd_1280.cu │ ├── ln_parallel_fwd_1536.cu │ ├── 
ln_parallel_fwd_2048.cu │ ├── ln_parallel_fwd_256.cu │ ├── ln_parallel_fwd_2560.cu │ ├── ln_parallel_fwd_3072.cu │ ├── ln_parallel_fwd_4096.cu │ ├── ln_parallel_fwd_512.cu │ ├── ln_parallel_fwd_5120.cu │ ├── ln_parallel_fwd_6144.cu │ ├── ln_parallel_fwd_7168.cu │ ├── ln_parallel_fwd_768.cu │ ├── ln_parallel_fwd_8192.cu │ ├── ln_parallel_residual_bwd_kernels.cuh │ ├── ln_parallel_residual_fwd_kernels.cuh │ ├── ln_utils.cuh │ ├── setup.py │ └── static_switch.h ├── examples/ │ └── inference/ │ └── README.md ├── flash_attn/ │ ├── __init__.py │ ├── bert_padding.py │ ├── cute/ │ │ ├── .flake8 │ │ ├── AUTHORS │ │ ├── LICENSE │ │ ├── MANIFEST.in │ │ ├── README.md │ │ ├── __init__.py │ │ ├── ampere_helpers.py │ │ ├── barrier.py │ │ ├── bench_utils.py │ │ ├── benchmark.py │ │ ├── blackwell_helpers.py │ │ ├── block_info.py │ │ ├── block_sparse_utils.py │ │ ├── block_sparsity.py │ │ ├── cache_utils.py │ │ ├── compute_block_sparsity.py │ │ ├── copy_utils.py │ │ ├── cute_dsl_ptxas.py │ │ ├── cute_dsl_utils.py │ │ ├── fa_logging.py │ │ ├── fast_math.py │ │ ├── flash_bwd.py │ │ ├── flash_bwd_postprocess.py │ │ ├── flash_bwd_preprocess.py │ │ ├── flash_bwd_sm100.py │ │ ├── flash_bwd_sm120.py │ │ ├── flash_bwd_sm90.py │ │ ├── flash_fwd.py │ │ ├── flash_fwd_combine.py │ │ ├── flash_fwd_sm100.py │ │ ├── flash_fwd_sm120.py │ │ ├── flash_fwd_sm90.py │ │ ├── interface.py │ │ ├── mask.py │ │ ├── mma_sm100_desc.py │ │ ├── named_barrier.py │ │ ├── pack_gqa.py │ │ ├── paged_kv.py │ │ ├── pipeline.py │ │ ├── pyproject.toml │ │ ├── seqlen_info.py │ │ ├── sm90_config_search.py │ │ ├── softmax.py │ │ ├── testing.py │ │ ├── tile_scheduler.py │ │ └── utils.py │ ├── flash_attn_interface.py │ ├── flash_attn_triton.py │ ├── flash_attn_triton_og.py │ ├── flash_blocksparse_attention.py │ ├── flash_blocksparse_attn_interface.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── patch_embed.py │ │ └── rotary.py │ ├── losses/ │ │ ├── __init__.py │ │ └── cross_entropy.py │ ├── models/ │ │ ├── __init__.py │ │ ├── 
baichuan.py │ │ ├── bert.py │ │ ├── bigcode.py │ │ ├── btlm.py │ │ ├── falcon.py │ │ ├── gpt.py │ │ ├── gpt_neox.py │ │ ├── gptj.py │ │ ├── llama.py │ │ ├── opt.py │ │ └── vit.py │ ├── modules/ │ │ ├── __init__.py │ │ ├── block.py │ │ ├── embedding.py │ │ ├── mha.py │ │ └── mlp.py │ ├── ops/ │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── fused_dense.py │ │ ├── layer_norm.py │ │ ├── rms_norm.py │ │ └── triton/ │ │ ├── __init__.py │ │ ├── cross_entropy.py │ │ ├── k_activations.py │ │ ├── layer_norm.py │ │ ├── linear.py │ │ ├── mlp.py │ │ └── rotary.py │ ├── pyproject.toml │ └── utils/ │ ├── __init__.py │ ├── benchmark.py │ ├── distributed.py │ ├── generation.py │ ├── library.py │ ├── pretrained.py │ ├── testing.py │ └── torch.py ├── hopper/ │ ├── __init__.py │ ├── benchmark_attn.py │ ├── benchmark_flash_attention_fp8.py │ ├── benchmark_mla_decode.py │ ├── benchmark_split_kv.py │ ├── block.h │ ├── copy_sm90_bulk_reduce.hpp │ ├── cuda_check.h │ ├── epilogue_bwd.hpp │ ├── epilogue_fwd.hpp │ ├── flash.h │ ├── flash_api.cpp │ ├── flash_api_stable.cpp │ ├── flash_attn_interface.py │ ├── flash_bwd_kernel_sm80.h │ ├── flash_bwd_kernel_sm90.h │ ├── flash_bwd_launch_template.h │ ├── flash_bwd_postprocess_kernel.h │ ├── flash_bwd_preprocess_kernel.h │ ├── flash_fwd_combine.cu │ ├── flash_fwd_combine_kernel.h │ ├── flash_fwd_combine_launch_template.h │ ├── flash_fwd_kernel_sm80.h │ ├── flash_fwd_kernel_sm90.h │ ├── flash_fwd_launch_template.h │ ├── flash_prepare_scheduler.cu │ ├── generate_kernels.py │ ├── heuristics.h │ ├── instantiations/ │ │ ├── flash_bwd_hdim128_bf16_sm80.cu │ │ ├── flash_bwd_hdim128_bf16_sm90.cu │ │ ├── flash_bwd_hdim128_bf16_softcap_sm80.cu │ │ ├── flash_bwd_hdim128_bf16_softcap_sm90.cu │ │ ├── flash_bwd_hdim128_bf16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim128_fp16_sm80.cu │ │ ├── flash_bwd_hdim128_fp16_sm90.cu │ │ ├── flash_bwd_hdim128_fp16_softcap_sm80.cu │ │ ├── flash_bwd_hdim128_fp16_softcap_sm90.cu │ │ ├── 
flash_bwd_hdim128_fp16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim192_bf16_sm80.cu │ │ ├── flash_bwd_hdim192_bf16_sm90.cu │ │ ├── flash_bwd_hdim192_bf16_softcap_sm80.cu │ │ ├── flash_bwd_hdim192_bf16_softcap_sm90.cu │ │ ├── flash_bwd_hdim192_bf16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim192_fp16_sm80.cu │ │ ├── flash_bwd_hdim192_fp16_sm90.cu │ │ ├── flash_bwd_hdim192_fp16_softcap_sm80.cu │ │ ├── flash_bwd_hdim192_fp16_softcap_sm90.cu │ │ ├── flash_bwd_hdim192_fp16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim256_bf16_sm80.cu │ │ ├── flash_bwd_hdim256_bf16_sm90.cu │ │ ├── flash_bwd_hdim256_bf16_softcap_sm80.cu │ │ ├── flash_bwd_hdim256_bf16_softcap_sm90.cu │ │ ├── flash_bwd_hdim256_bf16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim256_fp16_sm80.cu │ │ ├── flash_bwd_hdim256_fp16_sm90.cu │ │ ├── flash_bwd_hdim256_fp16_softcap_sm80.cu │ │ ├── flash_bwd_hdim256_fp16_softcap_sm90.cu │ │ ├── flash_bwd_hdim256_fp16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim64_bf16_sm80.cu │ │ ├── flash_bwd_hdim64_bf16_sm90.cu │ │ ├── flash_bwd_hdim64_bf16_softcap_sm80.cu │ │ ├── flash_bwd_hdim64_bf16_softcap_sm90.cu │ │ ├── flash_bwd_hdim64_bf16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim64_fp16_sm80.cu │ │ ├── flash_bwd_hdim64_fp16_sm90.cu │ │ ├── flash_bwd_hdim64_fp16_softcap_sm80.cu │ │ ├── flash_bwd_hdim64_fp16_softcap_sm90.cu │ │ ├── flash_bwd_hdim64_fp16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim96_bf16_sm80.cu │ │ ├── flash_bwd_hdim96_bf16_sm90.cu │ │ ├── flash_bwd_hdim96_bf16_softcap_sm80.cu │ │ ├── flash_bwd_hdim96_bf16_softcap_sm90.cu │ │ ├── flash_bwd_hdim96_bf16_softcapall_sm90.cu │ │ ├── flash_bwd_hdim96_fp16_sm80.cu │ │ ├── flash_bwd_hdim96_fp16_sm90.cu │ │ ├── flash_bwd_hdim96_fp16_softcap_sm80.cu │ │ ├── flash_bwd_hdim96_fp16_softcap_sm90.cu │ │ ├── flash_bwd_hdim96_fp16_softcapall_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_paged_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_paged_softcap_sm80.cu │ │ ├── 
flash_fwd_hdim128_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_sm100.cu │ │ ├── flash_fwd_hdim128_bf16_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_softcap_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_split_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim128_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_bf16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim128_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdim128_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_paged_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu │ │ ├── 
flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_softcap_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_split_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim128_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim128_fp16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_128_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_sm90.cu │ │ ├── 
flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_paged_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_split_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_bf16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdim192_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_paged_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_paged_sm90.cu │ │ ├── 
flash_fwd_hdim192_fp16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_split_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim192_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim192_fp16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_paged_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_softcap_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_split_sm80.cu │ │ ├── flash_fwd_hdim256_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_split_softcap_sm80.cu │ │ ├── 
flash_fwd_hdim256_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_bf16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdim256_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_paged_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_softcap_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_split_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim256_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim256_fp16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim64_256_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu │ │ ├── 
flash_fwd_hdim64_256_bf16_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_paged_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_paged_sm90.cu │ │ ├── 
flash_fwd_hdim64_bf16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_split_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_bf16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim64_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdim64_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_paged_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu 
│ │ ├── flash_fwd_hdim64_fp16_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_split_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim64_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim64_fp16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_paged_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_split_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_split_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_bf16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_softcap_sm90.cu │ │ 
├── flash_fwd_hdim96_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdim96_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_paged_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_paged_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_paged_split_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_softcapall_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_split_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_split_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_split_softcap_sm80.cu │ │ ├── flash_fwd_hdim96_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdim96_fp16_split_softcapall_sm80.cu │ │ ├── flash_fwd_hdimall_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_split_sm90.cu │ │ ├── flash_fwd_hdimall_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_sm90.cu │ │ ├── 
flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdimall_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_split_sm90.cu │ │ ├── flash_fwd_hdimall_fp16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_packgqa_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_paged_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_paged_split_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_split_sm90.cu │ │ ├── flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_paged_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_split_sm90.cu │ │ ├── flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_packgqa_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_paged_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_paged_split_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_sm90.cu │ │ 
├── flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_softcap_sm90.cu │ │ ├── flash_fwd_hdimdiff_fp16_split_sm90.cu │ │ └── flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu │ ├── mainloop_bwd_sm80.hpp │ ├── mainloop_bwd_sm90_tma_gmma_ws.hpp │ ├── mainloop_fwd_sm80.hpp │ ├── mainloop_fwd_sm90_tma_gmma_ws.hpp │ ├── mask.h │ ├── named_barrier.hpp │ ├── pack_gqa.h │ ├── padding.py │ ├── paged_kv.h │ ├── rotary.h │ ├── seqlen.h │ ├── setup.py │ ├── sm90_pipeline_no_cluster.hpp │ ├── softmax.h │ ├── static_switch.h │ ├── test_attn_kvcache.py │ ├── test_flash_attn.py │ ├── test_flash_attn_bwd_determinism.py │ ├── test_flash_attn_triton_amd.py │ ├── test_kvcache.py │ ├── test_torch_compile_and_export.py │ ├── test_util.py │ ├── tile_scheduler.hpp │ ├── tile_size.h │ └── utils.h ├── setup.py ├── tests/ │ ├── cute/ │ │ ├── benchmark_block_sparsity.py │ │ ├── benchmark_mask_mod.py │ │ ├── conftest.py │ │ ├── mask_mod_definitions.py │ │ ├── score_mod_definitions.py │ │ ├── test_block_sparsity.py │ │ ├── test_flash_attn.py │ │ ├── test_flash_attn_combine.py │ │ ├── test_flash_attn_fast.py │ │ ├── test_flash_attn_race_condition.py │ │ ├── test_flash_attn_varlen.py │ │ ├── test_mask_mod.py │ │ ├── test_score_mod.py │ │ ├── test_score_mod_varlen.py │ │ └── test_utils.py │ ├── layers/ │ │ └── test_rotary.py │ ├── losses/ │ │ ├── test_cross_entropy.py │ │ └── test_cross_entropy_parallel.py │ ├── models/ │ │ ├── test_baichuan.py │ │ ├── test_bert.py │ │ ├── test_bigcode.py │ │ ├── test_btlm.py │ │ ├── test_falcon.py │ │ ├── test_gpt.py │ │ ├── test_gpt_generation_parallel.py │ │ ├── test_gpt_neox.py │ │ ├── test_gpt_parallel.py │ │ ├── test_gptj.py │ │ ├── test_llama.py │ │ ├── test_opt.py │ │ └── test_vit.py │ ├── modules/ │ │ ├── test_block_parallel.py │ │ ├── test_embedding_parallel.py │ │ ├── test_mha_parallel.py │ │ └── test_mlp_parallel.py │ ├── ops/ │ │ ├── test_dropout_layer_norm.py │ │ ├── test_fused_dense.py │ │ ├── test_fused_dense_parallel.py │ 
│ └── triton/ │ │ └── test_layer_norm.py │ ├── pyproject.toml │ ├── test_flash_attn.py │ ├── test_flash_attn_ck.py │ ├── test_flash_attn_triton_amd.py │ ├── test_rotary.py │ └── test_util.py ├── tools/ │ └── sass_diff.py ├── training/ │ ├── Dockerfile │ ├── README.md │ ├── configs/ │ │ ├── callbacks/ │ │ │ ├── causality-monitor.yaml │ │ │ ├── default.yaml │ │ │ ├── ema.yaml │ │ │ ├── flop-count.yaml │ │ │ ├── gpu-monitor.yaml │ │ │ ├── model-summary.yaml │ │ │ ├── none.yaml │ │ │ ├── norm-monitor.yaml │ │ │ ├── params-log.yaml │ │ │ └── wandb.yaml │ │ ├── config.yaml │ │ ├── datamodule/ │ │ │ ├── openwebtext.yaml │ │ │ └── thepile.yaml │ │ ├── experiment/ │ │ │ ├── owt/ │ │ │ │ ├── base.yaml │ │ │ │ ├── gpt2l-flash.yaml │ │ │ │ ├── gpt2l-hf.yaml │ │ │ │ ├── gpt2l.yaml │ │ │ │ ├── gpt2m-flash.yaml │ │ │ │ ├── gpt2m-hf.yaml │ │ │ │ ├── gpt2m.yaml │ │ │ │ ├── gpt2s-flash.yaml │ │ │ │ ├── gpt2s-hf.yaml │ │ │ │ ├── gpt2s.yaml │ │ │ │ ├── gpt2xl-flash.yaml │ │ │ │ ├── gpt2xl-hf.yaml │ │ │ │ └── gpt2xl.yaml │ │ │ └── pile/ │ │ │ ├── base.yaml │ │ │ ├── gpt3-2.7B-flash-8k.yaml │ │ │ ├── gpt3-2.7B-flash-hdim128-rotary-8k.yaml │ │ │ ├── gpt3-2.7B-flash-hdim128-rotary.yaml │ │ │ ├── gpt3-2.7B-flash-hdim128.yaml │ │ │ ├── gpt3-2.7B-flash-rotary-8k.yaml │ │ │ ├── gpt3-2.7B-flash-rotary.yaml │ │ │ ├── gpt3-2.7B-flash.yaml │ │ │ ├── gpt3-2.7B-hf-hdim128.yaml │ │ │ ├── gpt3-2.7B-hf.yaml │ │ │ ├── gpt3l-flash-8k.yaml │ │ │ ├── gpt3l-flash-rotary-30B.yaml │ │ │ ├── gpt3l-flash-rotary-8k.yaml │ │ │ ├── gpt3l-flash-rotary.yaml │ │ │ ├── gpt3l-flash.yaml │ │ │ ├── gpt3l-hf.yaml │ │ │ ├── gpt3m-flash-8k.yaml │ │ │ ├── gpt3m-flash-rotary-30B.yaml │ │ │ ├── gpt3m-flash-rotary-8k.yaml │ │ │ ├── gpt3m-flash-rotary.yaml │ │ │ ├── gpt3m-flash.yaml │ │ │ ├── gpt3m-hf.yaml │ │ │ ├── gpt3s-flash-8k.yaml │ │ │ ├── gpt3s-flash-rotary-30B.yaml │ │ │ ├── gpt3s-flash-rotary-8k.yaml │ │ │ ├── gpt3s-flash-rotary.yaml │ │ │ ├── gpt3s-flash.yaml │ │ │ ├── gpt3s-hf.yaml │ │ │ ├── gpt3xl-flash-8k.yaml │ │ 
│ ├── gpt3xl-flash-rotary-60B.yaml │ │ │ ├── gpt3xl-flash-rotary-8k.yaml │ │ │ ├── gpt3xl-flash-rotary.yaml │ │ │ ├── gpt3xl-flash.yaml │ │ │ └── gpt3xl-hf.yaml │ │ ├── logger/ │ │ │ ├── comet.yaml │ │ │ ├── csv.yaml │ │ │ ├── many_loggers.yaml │ │ │ ├── mlflow.yaml │ │ │ ├── neptune.yaml │ │ │ ├── tensorboard.yaml │ │ │ └── wandb.yaml │ │ ├── metrics/ │ │ │ ├── acc.yaml │ │ │ ├── acc_ignore_index.yaml │ │ │ ├── acctop5.yaml │ │ │ ├── mse.yaml │ │ │ ├── num-tokens.yaml │ │ │ └── perplexity.yaml │ │ ├── mode/ │ │ │ ├── debug.yaml │ │ │ ├── default.yaml │ │ │ ├── exp.yaml │ │ │ ├── profile.yaml │ │ │ └── smoke.yaml │ │ ├── model/ │ │ │ ├── gpt2-hf.yaml │ │ │ ├── gpt2.yaml │ │ │ └── gpt2model/ │ │ │ ├── gpt2-large.yaml │ │ │ ├── gpt2-medium.yaml │ │ │ ├── gpt2-small.yaml │ │ │ └── gpt2-xlarge.yaml │ │ ├── optimizer/ │ │ │ ├── adam.yaml │ │ │ ├── adamw-apex-distributed.yaml │ │ │ ├── adamw-apex-zero.yaml │ │ │ ├── adamw-apex.yaml │ │ │ ├── adamw-zero.yaml │ │ │ ├── adamw.yaml │ │ │ ├── fusedlamb-ds.yaml │ │ │ ├── fusedlamb.yaml │ │ │ └── sgd.yaml │ │ ├── scheduler/ │ │ │ ├── cosine-warmup-timm.yaml │ │ │ ├── cosine-warmup.yaml │ │ │ ├── invsqrt.yaml │ │ │ ├── linear-warmup.yaml │ │ │ ├── multi-step.yaml │ │ │ ├── plateau.yaml │ │ │ ├── poly-warmup.yaml │ │ │ └── step.yaml │ │ ├── task/ │ │ │ └── sequence-model.yaml │ │ └── trainer/ │ │ ├── all_params.yaml │ │ ├── ddp.yaml │ │ ├── debug.yaml │ │ └── default.yaml │ ├── run.py │ ├── src/ │ │ ├── callbacks/ │ │ │ ├── __init__.py │ │ │ ├── causality_monitor.py │ │ │ ├── ema.py │ │ │ ├── flop_count.py │ │ │ ├── gpu_affinity.py │ │ │ ├── loss_scale_monitor.py │ │ │ ├── model_checkpoint.py │ │ │ ├── norm_monitor.py │ │ │ ├── params_log.py │ │ │ ├── speed_monitor.py │ │ │ └── wandb_callbacks.py │ │ ├── datamodules/ │ │ │ ├── datasets/ │ │ │ │ ├── detokenizer.py │ │ │ │ └── lm_dataset.py │ │ │ ├── fault_tolerant_sampler.py │ │ │ ├── imagenet.py │ │ │ ├── language_modeling_hf.py │ │ │ └── timm_mixup.py │ │ ├── distributed/ │ │ │ 
└── ddp_comm_hooks.py │ │ ├── eval.py │ │ ├── metrics/ │ │ │ ├── accuracy.py │ │ │ ├── num_tokens.py │ │ │ └── perplexity.py │ │ ├── models/ │ │ │ └── modules/ │ │ │ └── seq_common.py │ │ ├── optim/ │ │ │ ├── param_grouping.py │ │ │ └── timm_lr_scheduler.py │ │ ├── tasks/ │ │ │ └── seq.py │ │ ├── train.py │ │ └── utils/ │ │ ├── checkpoint.py │ │ ├── ddp_zero1.py │ │ ├── ddp_zero2.py │ │ ├── distributed.py │ │ ├── ema.py │ │ ├── flops.py │ │ ├── gpu_affinity.py │ │ └── utils.py │ └── tests/ │ └── datamodules/ │ └── test_language_modeling_hf.py └── usage.md
SYMBOL INDEX (2142 symbols across 245 files)
FILE: AI/racecheck_repro_1d_bulk.py
function kernel (line 24) | def kernel(g_src: cute.Tensor, g_dst: cute.Tensor):
function go (line 63) | def go(g_src, g_dst, stream):
FILE: AI/racecheck_repro_1d_tensor.py
function kernel (line 24) | def kernel(g_dst: cute.Tensor, tma_atom: cute.CopyAtom, tma_tensor: cute...
function go (line 68) | def go(g_src, g_dst, stream):
FILE: benchmarks/bench_sm90.py
function parse_int_k (line 45) | def parse_int_k(s):
function csv_ints (line 53) | def csv_ints(s):
function parse_headdims (line 58) | def parse_headdims(s):
function nheads_for_hdim (line 78) | def nheads_for_hdim(h):
function fwd_flops (line 82) | def fwd_flops(batch, nheads, seqlen, hdim, hdim_v=None, causal=False):
function bwd_flops (line 89) | def bwd_flops(batch, nheads, seqlen, hdim, causal=False, hdim_v=None):
function get_causals (line 93) | def get_causals(args):
function auto_batch (line 101) | def auto_batch(seqlen, batch_arg, total_tokens=32768):
function bench_fwd (line 107) | def bench_fwd(batch, seqlen, nheads, hdim, causal, tile_m=None, tile_n=N...
function bench_bwd (line 154) | def bench_bwd(batch, seqlen, nheads, hdim, causal, warmup=5, rep=30, hdi...
function _get_default_bwd_config (line 257) | def _get_default_bwd_config(headdim, causal=False):
function run_default (line 334) | def run_default(args):
function run_sweep_tiles (line 367) | def run_sweep_tiles(args):
function run_sweep_rs_overlap (line 397) | def run_sweep_rs_overlap(args):
function run_compare_configs (line 427) | def run_compare_configs(args):
function run_sweep_bwd_opts (line 452) | def run_sweep_bwd_opts(args):
function main (line 489) | def main():
FILE: benchmarks/benchmark_alibi.py
function generate_cos_sin (line 23) | def generate_cos_sin(seqlen, rotary_dim, device, dtype):
function flash_rotary (line 31) | def flash_rotary(q, k, v, cos, sin, causal=False):
function attn_bias_from_alibi_slopes (line 43) | def attn_bias_from_alibi_slopes(
function flops (line 68) | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
function efficiency (line 74) | def efficiency(flop, time):
function attention_pytorch (line 78) | def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
function time_fwd_bwd (line 110) | def time_fwd_bwd(func, *args, **kwargs):
FILE: benchmarks/benchmark_attn.py
function _make_bwd_fn (line 43) | def _make_bwd_fn(fwd_fn, g, inputs):
function setup_standard (line 67) | def setup_standard(ctx):
function setup_fa2 (line 76) | def setup_fa2(ctx):
function setup_cudnn (line 95) | def setup_cudnn(ctx):
function setup_fa3 (line 110) | def setup_fa3(ctx):
function setup_fa4 (line 134) | def setup_fa4(ctx):
function parse_int_k (line 173) | def parse_int_k(s):
function csv_ints (line 181) | def csv_ints(s):
function parse_headdims (line 186) | def parse_headdims(s):
function csv_strs (line 206) | def csv_strs(s):
function parse_args (line 211) | def parse_args():
function main (line 242) | def main():
FILE: benchmarks/benchmark_causal.py
function attention_pytorch (line 20) | def attention_pytorch(qkv, dropout_p=0.0, causal=True):
function time_fwd_bwd (line 122) | def time_fwd_bwd(func, *args, **kwargs):
FILE: benchmarks/benchmark_flash_attention.py
function flops (line 27) | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
function efficiency (line 32) | def efficiency(flop, time):
function attention_pytorch (line 36) | def attention_pytorch(qkv, dropout_p=0.0, causal=True):
function time_fwd_bwd (line 65) | def time_fwd_bwd(func, *args, **kwargs):
FILE: benchmarks/benchmark_gemm.py
function benchmark_forward (line 12) | def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **...
FILE: csrc/flash_attn/flash_api.cpp
type FLASH_NAMESPACE (line 24) | namespace FLASH_NAMESPACE {
function set_params_fprop (line 26) | void set_params_fprop(Flash_fwd_params &params,
function set_params_dgrad (line 161) | void set_params_dgrad(Flash_bwd_params &params,
function run_mha_fwd (line 243) | void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool f...
function num_splits_heuristic (line 263) | inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs,...
function set_params_splitkv (line 299) | std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params...
function set_params_alibi (line 331) | void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tens...
function mha_fwd (line 350) | std::vector<at::Tensor>
function mha_varlen_fwd (line 514) | std::vector<at::Tensor>
function run_mha_bwd (line 757) | void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
function mha_bwd (line 767) | std::vector<at::Tensor>
function mha_varlen_bwd (line 973) | std::vector<at::Tensor>
function mha_fwd_kvcache (line 1202) | std::vector<at::Tensor>
function PYBIND11_MODULE (line 1478) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/flash_attn/src/alibi.h
function namespace (line 11) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/block_info.h
function namespace (line 8) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/dropout.h
function namespace (line 11) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/flash.h
function namespace (line 14) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/flash_bwd_kernel.h
function namespace (line 23) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/flash_bwd_launch_template.h
function namespace (line 16) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
function namespace (line 18) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/flash_fwd_kernel.h
function namespace (line 24) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/generate_kernels.py
function get_fwd_template (line 17) | def get_fwd_template() -> str:
function get_fwd_split_template (line 29) | def get_fwd_split_template() -> str:
function get_bwd_template (line 38) | def get_bwd_template() -> str:
class Kernel (line 51) | class Kernel:
method template (line 59) | def template(self) -> str:
method filename (line 73) | def filename(self) -> str:
function get_all_kernels (line 76) | def get_all_kernels() -> List[Kernel]:
function write_kernel (line 81) | def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
function main (line 88) | def main(output_dir: Optional[str]) -> None:
FILE: csrc/flash_attn/src/hardware_info.h
function get_current_device (line 24) | inline int get_current_device() {
function get_num_sm (line 37) | inline int get_num_sm(int device) {
FILE: csrc/flash_attn/src/mask.h
function namespace (line 10) | namespace FLASH_NAMESPACE {
function apply_mask_local (line 39) | void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_...
FILE: csrc/flash_attn/src/rotary.h
function namespace (line 14) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn/src/softmax.h
function namespace (line 17) | namespace FLASH_NAMESPACE {
function __forceinline__ (line 134) | __forceinline__ __device__ Softmax() {}
FILE: csrc/flash_attn/src/utils.h
function namespace (line 28) | namespace FLASH_NAMESPACE {
FILE: csrc/flash_attn_ck/flash_api.cpp
function PYBIND11_MODULE (line 114) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
FILE: csrc/flash_attn_ck/flash_common.cpp
type flash (line 7) | namespace flash {
function override_num_splits_if_necessary (line 8) | int override_num_splits_if_necessary(int batch, int nhead, int max_seq...
FILE: csrc/flash_attn_ck/flash_common.hpp
type flash (line 24) | namespace flash {
function __global__ (line 25) | inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, u...
function num_splits_heuristic_ck (line 38) | inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_S...
FILE: csrc/flash_attn_ck/mha_bwd.cpp
function fmha_bwd_traits (line 10) | fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
function fmha_bwd_args (line 41) | fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
function mha_bwd (line 211) | std::vector<at::Tensor>
FILE: csrc/flash_attn_ck/mha_fwd.cpp
function fmha_fwd_traits (line 10) | fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
function fmha_fwd_args (line 30) | fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
function mha_fwd (line 165) | std::vector<at::Tensor>
FILE: csrc/flash_attn_ck/mha_fwd_kvcache.cpp
function fmha_fwd_appendkv_traits (line 10) | fmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype,
function fmha_fwd_splitkv_traits (line 26) | fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &...
function fmha_fwd_appendkv_args (line 45) | fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
function fmha_fwd_splitkv_args (line 138) | fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
function mha_fwd_kvcache (line 272) | std::vector<at::Tensor>
FILE: csrc/flash_attn_ck/mha_varlen_bwd.cpp
function fmha_bwd_traits (line 10) | fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
function fmha_bwd_args (line 42) | fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
function mha_varlen_bwd (line 218) | std::vector<at::Tensor>
FILE: csrc/flash_attn_ck/mha_varlen_fwd.cpp
function fmha_fwd_traits (line 10) | fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
function fmha_fwd_splitkv_traits (line 30) | fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask...
function fmha_fwd_args (line 49) | fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
function fmha_fwd_splitkv_args (line 187) | fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
function mha_varlen_fwd (line 320) | std::vector<at::Tensor>
FILE: csrc/fused_dense_lib/fused_dense.cpp
function linear_bias_wgrad (line 40) | std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d...
function linear_act_forward (line 92) | std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor ...
function bias_act_linear_dgrad_bgrad (line 154) | std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
function PYBIND11_MODULE (line 209) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/fused_dense_lib/setup.py
function get_cuda_bare_metal_version (line 10) | def get_cuda_bare_metal_version(cuda_dir):
function append_nvcc_threads (line 19) | def append_nvcc_threads(nvcc_extra_args):
FILE: csrc/layer_norm/ln.h
function namespace (line 13) | namespace layer_norm {
FILE: csrc/layer_norm/ln_api.cpp
type layer_norm (line 27) | namespace layer_norm {
function get_type_id (line 36) | uint32_t get_type_id(torch::Dtype dtype){
function get_key (line 50) | uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype ...
function dropout_add_ln_fwd (line 105) | std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, //...
function dropout_add_ln_bwd (line 282) | std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // ...
function dropout_add_ln_parallel_residual_fwd (line 482) | std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
function dropout_add_ln_parallel_residual_bwd (line 649) | std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
function PYBIND11_MODULE (line 826) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/layer_norm/ln_kernel_traits.h
function namespace (line 5) | namespace layer_norm {
function Base (line 110) | struct Kernel_traits : public Base {
FILE: csrc/layer_norm/setup.py
function get_cuda_bare_metal_version (line 16) | def get_cuda_bare_metal_version(cuda_dir):
function check_cuda_torch_binary_vs_bare_metal (line 25) | def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
function raise_if_cuda_home_none (line 43) | def raise_if_cuda_home_none(global_option: str) -> None:
function append_nvcc_threads (line 53) | def append_nvcc_threads(nvcc_extra_args):
FILE: flash_attn/bert_padding.py
class IndexFirstAxis (line 8) | class IndexFirstAxis(torch.autograd.Function):
method forward (line 10) | def forward(ctx, input, indices):
method backward (line 22) | def backward(ctx, grad_output):
class IndexPutFirstAxis (line 41) | class IndexPutFirstAxis(torch.autograd.Function):
method forward (line 43) | def forward(ctx, values, indices, first_axis_dim):
method backward (line 56) | def backward(ctx, grad_output):
class IndexFirstAxisResidual (line 67) | class IndexFirstAxisResidual(torch.autograd.Function):
method forward (line 69) | def forward(ctx, input, indices):
method backward (line 82) | def backward(ctx, grad_output, grad_residual):
function unpad_input (line 98) | def unpad_input(hidden_states, attention_mask, unused_mask=None):
function unpad_input_for_concatenated_sequences (line 131) | def unpad_input_for_concatenated_sequences(hidden_states, attention_mask...
function pad_input (line 204) | def pad_input(hidden_states, indices, batch, seqlen):
FILE: flash_attn/cute/ampere_helpers.py
function get_smem_layout_atom (line 8) | def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cu...
function gemm (line 35) | def gemm(
function gemm_rs (line 87) | def gemm_rs(
FILE: flash_attn/cute/barrier.py
function ld_acquire (line 9) | def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass....
function red_relaxed (line 24) | def red_relaxed(
function red_release (line 40) | def red_release(
function wait_eq (line 56) | def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset...
function arrive_inc (line 65) | def arrive_inc(
FILE: flash_attn/cute/bench_utils.py
function flops (line 15) | def flops(
function attention_ref (line 46) | def attention_ref(q, k, v, causal=False):
function _build_cudnn_graph (line 79) | def _build_cudnn_graph(io_dtype, tensors, build_fn):
function cudnn_fwd_setup (line 99) | def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None):
function cudnn_bwd_setup (line 141) | def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=N...
FILE: flash_attn/cute/benchmark.py
function benchmark_forward (line 8) | def benchmark_forward(
function benchmark_backward (line 30) | def benchmark_backward(
function benchmark_combined (line 72) | def benchmark_combined(
function benchmark_fwd_bwd (line 117) | def benchmark_fwd_bwd(
function benchmark_all (line 154) | def benchmark_all(
function pytorch_profiler (line 202) | def pytorch_profiler(
function benchmark_memory (line 258) | def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
FILE: flash_attn/cute/blackwell_helpers.py
function gemm_w_idx (line 14) | def gemm_w_idx(
function gemm_ptx_w_idx (line 42) | def gemm_ptx_w_idx(
function gemm (line 77) | def gemm(
function i64_to_i32x2 (line 90) | def i64_to_i32x2(i: int) -> Tuple[int, int]:
function gemm_ptx (line 96) | def gemm_ptx(
function gemm_ptx_loop (line 211) | def gemm_ptx_loop(
function gemm_ptx_partial (line 374) | def gemm_ptx_partial(
function gemm_ptx_partial1 (line 594) | def gemm_ptx_partial1(
function gemm_ptx_precomputed (line 773) | def gemm_ptx_precomputed(
function declare_ptx_smem_desc (line 952) | def declare_ptx_smem_desc(
function declare_ptx_idesc (line 996) | def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = ...
function gemm_ptx_precomputed_varname (line 1011) | def gemm_ptx_precomputed_varname(
FILE: flash_attn/cute/block_info.py
class BlockInfo (line 13) | class BlockInfo:
method get_n_block_min_max (line 24) | def get_n_block_min_max(
method get_m_block_min_max (line 58) | def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int3...
method get_n_block_k_new_min_max (line 74) | def get_n_block_k_new_min_max(
method get_n_block_min_causal_local_mask (line 105) | def get_n_block_min_causal_local_mask(
method get_n_block_min_before_local_mask (line 124) | def get_n_block_min_before_local_mask(
FILE: flash_attn/cute/block_sparse_utils.py
function load_block_list (line 72) | def load_block_list(
function finish_overlap_v_load (line 127) | def finish_overlap_v_load(
function sparse_tensor_m_block (line 145) | def sparse_tensor_m_block(
function produce_block_sparse_loads (line 160) | def produce_block_sparse_loads(
function consume_block_sparse_loads (line 303) | def consume_block_sparse_loads(
function load_block_list_sm100 (line 487) | def load_block_list_sm100(
function produce_block_sparse_loads_sm100 (line 529) | def produce_block_sparse_loads_sm100(
function get_total_block_count (line 622) | def get_total_block_count(
function handle_block_sparse_empty_tile_correction_sm100 (line 643) | def handle_block_sparse_empty_tile_correction_sm100(
function softmax_block_sparse_sm100 (line 756) | def softmax_block_sparse_sm100(
function get_total_q_block_count_bwd (line 896) | def get_total_q_block_count_bwd(
function produce_block_sparse_q_loads_bwd_sm100 (line 913) | def produce_block_sparse_q_loads_bwd_sm100(
function get_block_sparse_iteration_info_bwd (line 1037) | def get_block_sparse_iteration_info_bwd(
function get_m_block_from_iter_bwd (line 1069) | def get_m_block_from_iter_bwd(
function _load_q_do_block_sm90 (line 1102) | def _load_q_do_block_sm90(
function produce_block_sparse_q_loads_bwd_sm90 (line 1145) | def produce_block_sparse_q_loads_bwd_sm90(
function consume_block_sparse_mma_bwd_sm90 (line 1241) | def consume_block_sparse_mma_bwd_sm90(
function _store_one_dQaccum_sm90 (line 1347) | def _store_one_dQaccum_sm90(
function dQaccum_store_block_sparse_bwd_sm90 (line 1377) | def dQaccum_store_block_sparse_bwd_sm90(
FILE: flash_attn/cute/block_sparsity.py
function ceildiv (line 13) | def ceildiv(a: int, b: int) -> int:
class BlockSparseTensors (line 17) | class BlockSparseTensors(NamedTuple):
method __new_from_mlir_values__ (line 23) | def __new_from_mlir_values__(self, values):
class BlockSparseTensorsTorch (line 29) | class BlockSparseTensorsTorch(NamedTuple):
function _expand_sparsity_tensor (line 37) | def _expand_sparsity_tensor(
function _check_and_expand_block (line 60) | def _check_and_expand_block(
function get_block_sparse_expected_shapes (line 90) | def get_block_sparse_expected_shapes(
function infer_block_sparse_expected_shapes (line 108) | def infer_block_sparse_expected_shapes(
function get_block_sparse_expected_shapes_bwd (line 202) | def get_block_sparse_expected_shapes_bwd(
function normalize_block_sparse_tensors (line 225) | def normalize_block_sparse_tensors(
function is_block_sparsity_enabled (line 269) | def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
function get_block_sparse_broadcast_pattern (line 273) | def get_block_sparse_broadcast_pattern(
function normalize_block_sparse_config (line 305) | def normalize_block_sparse_config(
function normalize_block_sparse_config_bwd (line 351) | def normalize_block_sparse_config_bwd(
function to_cute_block_sparse_tensors (line 398) | def to_cute_block_sparse_tensors(
function fast_sampling (line 437) | def fast_sampling(mask_mod):
FILE: flash_attn/cute/cache_utils.py
function get_cache_path (line 49) | def get_cache_path() -> Path:
function _compute_source_fingerprint (line 59) | def _compute_source_fingerprint() -> str:
class FileLock (line 88) | class FileLock:
method __init__ (line 99) | def __init__(
method _lock_label (line 120) | def _lock_label(self) -> str:
method __enter__ (line 124) | def __enter__(self) -> "FileLock":
method __exit__ (line 149) | def __exit__(self, exc_type, exc_val, exc_tb) -> None:
class JITCache (line 156) | class JITCache:
method __init__ (line 161) | def __init__(self):
method __setitem__ (line 164) | def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) ->...
method __getitem__ (line 167) | def __getitem__(self, key: CompileKeyType) -> CallableFunction:
method __contains__ (line 170) | def __contains__(self, key: CompileKeyType) -> bool:
method clear (line 173) | def clear(self) -> None:
class JITPersistentCache (line 180) | class JITPersistentCache(JITCache):
method __init__ (line 189) | def __init__(self, cache_path: Path):
method __setitem__ (line 194) | def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) ->...
method __getitem__ (line 198) | def __getitem__(self, key: CompileKeyType) -> CallableFunction:
method __contains__ (line 203) | def __contains__(self, key: CompileKeyType) -> bool:
method _try_load_from_storage (line 210) | def _try_load_from_storage(self, key: CompileKeyType) -> bool:
method _try_export_to_storage (line 234) | def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledF...
method _key_to_hash (line 255) | def _key_to_hash(self, key: CompileKeyType) -> str:
method _lock_path (line 258) | def _lock_path(self, sha256_hex: str) -> Path:
method clear (line 261) | def clear(self) -> None:
function get_jit_cache (line 271) | def get_jit_cache(name: str | None = None) -> JITCache:
FILE: flash_attn/cute/compute_block_sparsity.py
class BlockSparsityKernel (line 18) | class BlockSparsityKernel:
method __init__ (line 35) | def __init__(
method __call__ (line 50) | def __call__(
method kernel (line 87) | def kernel(
function compute_block_sparsity (line 277) | def compute_block_sparsity(
FILE: flash_attn/cute/copy_utils.py
function cvt_copy (line 17) | def cvt_copy(
function load_s2r (line 36) | def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
function get_copy_atom (line 43) | def get_copy_atom(
function make_tmem_copy (line 52) | def make_tmem_copy(
function copy (line 66) | def copy(
function tiled_copy_1d (line 81) | def tiled_copy_1d(
function tiled_copy_2d (line 92) | def tiled_copy_2d(
function atomic_add_fp32x4 (line 110) | def atomic_add_fp32x4(
function set_block_rank (line 148) | def set_block_rank(
function store_shared_remote_fp32x4 (line 167) | def store_shared_remote_fp32x4(
function cpasync_bulk_s2cluster (line 211) | def cpasync_bulk_s2cluster(
function cpasync_bulk_g2s (line 243) | def cpasync_bulk_g2s(
function cpasync_reduce_bulk_add_f32 (line 267) | def cpasync_reduce_bulk_add_f32(
function cpasync_bulk_get_copy_fn (line 291) | def cpasync_bulk_get_copy_fn(
function tma_get_copy_fn (line 324) | def tma_get_copy_fn(
function tma_producer_copy_fn (line 363) | def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.Pipe...
FILE: flash_attn/cute/cute_dsl_ptxas.py
function _log (line 25) | def _log(msg):
function _get_ptx (line 30) | def _get_ptx(compiled_func) -> tuple[str, Path] | None:
function _compile_ptx (line 45) | def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
function _patched_load_cuda_library (line 81) | def _patched_load_cuda_library(self):
function patch (line 132) | def patch():
FILE: flash_attn/cute/cute_dsl_utils.py
function get_max_active_clusters (line 37) | def get_max_active_clusters(cluster_size):
function get_device_capacity (line 42) | def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
class ArgumentsBase (line 47) | class ArgumentsBase(JitArgument):
method __c_pointers__ (line 48) | def __c_pointers__(self):
method __get_mlir_types__ (line 57) | def __get_mlir_types__(self):
method __new_from_mlir_values__ (line 70) | def __new_from_mlir_values__(self, values):
function load_cubin_module_data_patched (line 82) | def load_cubin_module_data_patched(cubin_data, filepath):
function cute_compile_patched (line 87) | def cute_compile_patched(*args, **kwargs):
function assume_strides_aligned (line 103) | def assume_strides_aligned(t):
function assume_tensor_aligned (line 114) | def assume_tensor_aligned(t):
function to_cute_tensor (line 121) | def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=Fa...
function to_cute_aux_tensor (line 131) | def to_cute_aux_tensor(t, enable_tvm_ffi=True):
function get_aux_tensor_metadata (line 149) | def get_aux_tensor_metadata(aux_tensors):
function get_broadcast_dims (line 160) | def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
FILE: flash_attn/cute/fa_logging.py
function _parse_log_level (line 38) | def _parse_log_level(raw: str) -> int:
function _configure_default_handler (line 55) | def _configure_default_handler() -> None:
function get_fa_log_level (line 73) | def get_fa_log_level() -> int:
function set_fa_log_level (line 77) | def set_fa_log_level(level: int | str) -> None:
function fa_log (line 90) | def fa_log(level: int, msg: str):
function fa_printf (line 95) | def fa_printf(level: int, fmt, *args):
FILE: flash_attn/cute/fast_math.py
function clz (line 9) | def clz(x: Int32) -> Int32:
FILE: flash_attn/cute/flash_bwd.py
class FlashAttentionBackwardSm80 (line 28) | class FlashAttentionBackwardSm80:
method __init__ (line 29) | def __init__(
method can_implement (line 95) | def can_implement(
method _check_type (line 141) | def _check_type(
method _setup_attributes (line 183) | def _setup_attributes(self):
method _get_tiled_mma (line 300) | def _get_tiled_mma(self):
method _get_shared_storage_cls (line 322) | def _get_shared_storage_cls(self):
method __call__ (line 364) | def __call__(
method kernel (line 477) | def kernel(
method compute_one_m_block (line 851) | def compute_one_m_block(
method epilogue (line 1008) | def epilogue(
method advance_pipeline (line 1137) | def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constex...
method load_K (line 1141) | def load_K(
method load_V (line 1170) | def load_V(
method load_Q_LSE (line 1198) | def load_Q_LSE(
method load_dO_dPsum (line 1242) | def load_dO_dPsum(
FILE: flash_attn/cute/flash_bwd_postprocess.py
class FlashAttentionBackwardPostprocess (line 34) | class FlashAttentionBackwardPostprocess:
method __init__ (line 35) | def __init__(
method can_implement (line 70) | def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
method _get_tiled_mma (line 91) | def _get_tiled_mma(self):
method _setup_attributes (line 132) | def _setup_attributes(self):
method __call__ (line 211) | def __call__(
method kernel (line 290) | def kernel(
FILE: flash_attn/cute/flash_bwd_preprocess.py
class FlashAttentionBackwardPreprocess (line 38) | class FlashAttentionBackwardPreprocess:
method __init__ (line 39) | def __init__(
method can_implement (line 69) | def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
method _setup_attributes (line 94) | def _setup_attributes(self):
method __call__ (line 125) | def __call__(
method kernel (line 221) | def kernel(
FILE: flash_attn/cute/flash_bwd_sm100.py
class FlashAttentionBackwardSm100 (line 47) | class FlashAttentionBackwardSm100:
method __init__ (line 50) | def __init__(
method _setup_attributes (line 236) | def _setup_attributes(self):
method _get_tiled_mma (line 266) | def _get_tiled_mma(self):
method _setup_smem_layout (line 320) | def _setup_smem_layout(self):
method __call__ (line 443) | def __call__(
method kernel (line 1009) | def kernel(
method relay (line 1621) | def relay(
method load (line 1665) | def load(
method mma (line 2196) | def mma(
method split_wg (line 2699) | def split_wg(
method apply_score_mod (line 2741) | def apply_score_mod(
method apply_score_mod_bwd (line 2780) | def apply_score_mod_bwd(
method compute_loop (line 2812) | def compute_loop(
method dQacc_reduce (line 3417) | def dQacc_reduce(
method epilogue_dKV (line 3657) | def epilogue_dKV(
method epilogue_dK_or_dV_tma (line 3794) | def epilogue_dK_or_dV_tma(
FILE: flash_attn/cute/flash_bwd_sm120.py
class FlashAttentionBackwardSm120 (line 14) | class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80):
method can_implement (line 16) | def can_implement(
FILE: flash_attn/cute/flash_bwd_sm90.py
class FlashAttentionBackwardSm90 (line 45) | class FlashAttentionBackwardSm90:
method __init__ (line 48) | def __init__(
method can_implement (line 145) | def can_implement(
method _check_type (line 169) | def _check_type(
method _setup_attributes (line 200) | def _setup_attributes(self):
method _get_tiled_mma (line 249) | def _get_tiled_mma(self):
method _get_shared_storage_cls (line 299) | def _get_shared_storage_cls(self):
method __call__ (line 337) | def __call__(
method kernel (line 613) | def kernel(
method load (line 833) | def load(
method apply_score_mod (line 1001) | def apply_score_mod(
method apply_score_mod_bwd (line 1044) | def apply_score_mod_bwd(
method mma (line 1088) | def mma(
method _get_stat (line 1428) | def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: boo...
method mma_one_m_block (line 1443) | def mma_one_m_block(
method epilogue_dKV (line 1603) | def epilogue_dKV(
method dQaccum_store (line 1738) | def dQaccum_store(
FILE: flash_attn/cute/flash_fwd.py
class FlashAttentionForwardBase (line 39) | class FlashAttentionForwardBase:
method __init__ (line 41) | def __init__(
method can_implement (line 113) | def can_implement(
method _check_type (line 170) | def _check_type(
method _setup_attributes (line 199) | def _setup_attributes(self):
method _get_smem_layout_atom (line 295) | def _get_smem_layout_atom(self):
method _get_tiled_mma (line 298) | def _get_tiled_mma(self):
method _get_shared_storage_cls (line 301) | def _get_shared_storage_cls(self):
method __call__ (line 305) | def __call__(
method epilogue (line 324) | def epilogue(
method advance_pipeline (line 449) | def advance_pipeline(self, pipeline_index):
method load_Q (line 453) | def load_Q(
method load_K (line 480) | def load_K(
method load_V (line 526) | def load_V(
class FlashAttentionForwardSm80 (line 576) | class FlashAttentionForwardSm80(FlashAttentionForwardBase):
method _get_smem_layout_atom (line 577) | def _get_smem_layout_atom(self):
method _get_tiled_mma (line 585) | def _get_tiled_mma(self):
method _get_shared_storage_cls (line 598) | def _get_shared_storage_cls(self):
method __call__ (line 620) | def __call__(
method kernel (line 743) | def kernel(
method compute_one_n_block (line 1082) | def compute_one_n_block(
function __getattr__ (line 1195) | def __getattr__(name):
FILE: flash_attn/cute/flash_fwd_combine.py
class FlashAttentionForwardCombine (line 21) | class FlashAttentionForwardCombine:
method __init__ (line 22) | def __init__(
method can_implement (line 57) | def can_implement(
method _setup_attributes (line 84) | def _setup_attributes(self):
method __call__ (line 191) | def __call__(
method kernel (line 324) | def kernel(
method load_O_partial (line 668) | def load_O_partial(
FILE: flash_attn/cute/flash_fwd_sm100.py
class FlashAttentionForwardSm100 (line 64) | class FlashAttentionForwardSm100:
method __init__ (line 66) | def __init__(
method _setup_attributes (line 243) | def _setup_attributes(self):
method __call__ (line 284) | def __call__(
method kernel (line 665) | def kernel(
method load (line 1122) | def load(
method mma (line 1323) | def mma(
method softmax_loop (line 1614) | def softmax_loop(
method softmax_step (line 1952) | def softmax_step(
method correction_loop (line 2089) | def correction_loop(
method correction_rescale (line 2368) | def correction_rescale(
method correction_epilogue (line 2419) | def correction_epilogue(
method _store_O_to_gmem (line 2508) | def _store_O_to_gmem(
method epilogue_s2g (line 2556) | def epilogue_s2g(
method load_Q (line 2626) | def load_Q(
method load_KV (line 2638) | def load_KV(
method offset_kv_smem (line 2687) | def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32):
method apply_score_mod (line 2722) | def apply_score_mod(
FILE: flash_attn/cute/flash_fwd_sm120.py
class FlashAttentionForwardSm120 (line 14) | class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
method can_implement (line 20) | def can_implement(
FILE: flash_attn/cute/flash_fwd_sm90.py
class FlashAttentionForwardSm90 (line 51) | class FlashAttentionForwardSm90(FlashAttentionForwardBase):
method __init__ (line 52) | def __init__(
method _get_smem_layout_atom (line 71) | def _get_smem_layout_atom(self):
method _get_tiled_mma (line 95) | def _get_tiled_mma(self):
method _get_shared_storage_cls (line 119) | def _get_shared_storage_cls(self):
method __call__ (line 157) | def __call__(
method kernel (line 398) | def kernel(
method load (line 625) | def load(
method load_KV (line 862) | def load_KV(
method mma (line 883) | def mma(
method first_half_block_overlap (line 1236) | def first_half_block_overlap(
method last_half_block_overlap (line 1292) | def last_half_block_overlap(
method mma_one_n_block (line 1315) | def mma_one_n_block(
method mma_one_n_block_intrawg_overlap (line 1377) | def mma_one_n_block_intrawg_overlap(
method mma_init (line 1446) | def mma_init(self):
method apply_score_mod (line 1456) | def apply_score_mod(
method warp_scheduler_barrier_sync (line 1490) | def warp_scheduler_barrier_sync(self):
method warp_scheduler_barrier_arrive (line 1499) | def warp_scheduler_barrier_arrive(self):
FILE: flash_attn/cute/interface.py
function _parse_arch_str (line 72) | def _parse_arch_str(arch_str):
function _get_device_arch (line 83) | def _get_device_arch():
function _validate_head_dims (line 101) | def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capabili...
class FwdConfig (line 120) | class FwdConfig:
function _tile_size_fwd_sm90 (line 127) | def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, use_b...
class BwdConfig (line 156) | class BwdConfig:
function _tile_size_bwd_sm90 (line 172) | def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local):
function maybe_contiguous (line 236) | def maybe_contiguous(x):
function _validate_tensor (line 240) | def _validate_tensor(t, name, expected_shape, expected_dtype, expected_d...
function num_splits_heuristic (line 255) | def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
function _resolve_causal_local_window (line 265) | def _resolve_causal_local_window(causal, window_size_left, window_size_r...
function _flash_attn_fwd (line 288) | def _flash_attn_fwd(
function make_fake_bwd_tensors (line 822) | def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k):
function _compile_bwd_preprocess (line 866) | def _compile_bwd_preprocess(
function _bwd_preprocess (line 886) | def _bwd_preprocess(
function _compile_bwd_postprocess (line 907) | def _compile_bwd_postprocess(
function _bwd_postprocess_convert (line 932) | def _bwd_postprocess_convert(
function _flash_attn_bwd (line 956) | def _flash_attn_bwd(
class FlashAttnFunc (line 1564) | class FlashAttnFunc(torch.autograd.Function):
method forward (line 1566) | def forward(
method backward (line 1624) | def backward(ctx, dout, dlse):
class FlashAttnVarlenFunc (line 1648) | class FlashAttnVarlenFunc(torch.autograd.Function):
method forward (line 1650) | def forward(
method backward (line 1710) | def backward(ctx, dout, dlse):
function flash_attn_func (line 1742) | def flash_attn_func(
function flash_attn_varlen_func (line 1784) | def flash_attn_varlen_func(
function _compile_fwd_combine (line 1832) | def _compile_fwd_combine(
function _flash_attn_fwd_combine (line 1889) | def _flash_attn_fwd_combine(
function flash_attn_combine (line 1983) | def flash_attn_combine(
FILE: flash_attn/cute/mask.py
function r2p_bitmask_below (line 19) | def r2p_bitmask_below(limit: Int32, s: int) -> Uint32:
function r2p_bitmask_above (line 30) | def r2p_bitmask_above(limit: Int32, s: int) -> Uint32:
function mask_r2p_lambda (line 41) | def mask_r2p_lambda(
function sm90_col_to_r2p_idx (line 69) | def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32:
function row_to_r2p_idx (line 80) | def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
class AttentionMask (line 103) | class AttentionMask:
method seqlen_q (line 113) | def seqlen_q(self) -> Int32:
method seqlen_k (line 117) | def seqlen_k(self) -> Int32:
method apply_mask (line 121) | def apply_mask(
method apply_mask_sm100 (line 370) | def apply_mask_sm100(
method apply_mask_sm100_transposed (line 530) | def apply_mask_sm100_transposed(
FILE: flash_attn/cute/mma_sm100_desc.py
class Major (line 16) | class Major(IntEnum): # matrix “layout” in the ISA docs
class ScaleIn (line 21) | class ScaleIn(IntEnum): # negate flags
class Saturate (line 26) | class Saturate(IntEnum):
class CFormat (line 31) | class CFormat(IntEnum): # 2-bit field (bits 4-5)
class F16F32Format (line 37) | class F16F32Format(IntEnum): # 3-bit field (A/B element type)
class S8Format (line 43) | class S8Format(IntEnum):
class MXF8F6F4Format (line 48) | class MXF8F6F4Format(IntEnum):
class MaxShift (line 56) | class MaxShift(IntEnum):
function to_UMMA_format (line 68) | def to_UMMA_format(cutlass_type) -> int:
function to_C_format (line 93) | def to_C_format(cutlass_type) -> int:
function make_instr_desc (line 111) | def make_instr_desc(
function mma_op_to_idesc (line 165) | def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):
class LayoutType (line 177) | class LayoutType(IntEnum): # occupies the top-3 bits [61:64)
function _layout_type (line 191) | def _layout_type(swizzle: cute.Swizzle) -> LayoutType:
function make_smem_desc_base (line 212) | def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, majo...
function make_smem_desc_start_addr (line 285) | def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:
function smem_desc_base_from_tensor (line 290) | def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int:
FILE: flash_attn/cute/named_barrier.py
class NamedBarrierFwd (line 6) | class NamedBarrierFwd(enum.IntEnum):
class NamedBarrierFwdSm100 (line 15) | class NamedBarrierFwdSm100(enum.IntEnum):
class NamedBarrierBwd (line 28) | class NamedBarrierBwd(enum.IntEnum):
class NamedBarrierBwdSm100 (line 42) | class NamedBarrierBwdSm100(enum.IntEnum):
FILE: flash_attn/cute/pack_gqa.py
function pack_gqa_layout (line 14) | def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx):
function make_packgqa_tiled_tma_atom (line 42) | def make_packgqa_tiled_tma_atom(
function unpack_gqa_layout (line 85) | def unpack_gqa_layout(T, qhead_per_kvhead, head_idx):
class PackGQA (line 114) | class PackGQA:
method __init__ (line 115) | def __init__(
method compute_ptr (line 128) | def compute_ptr(
method load_Q (line 148) | def load_Q(
method store_LSE (line 193) | def store_LSE(
method store_O (line 228) | def store_O(
FILE: flash_attn/cute/paged_kv.py
class PagedKVManager (line 17) | class PagedKVManager(ParamsBase):
method create (line 46) | def create(
method load_page_table (line 136) | def load_page_table(self, n_block: Int32):
method compute_X_ptr (line 157) | def compute_X_ptr(self, K_or_V: str):
method load_KV (line 173) | def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
FILE: flash_attn/cute/pipeline.py
class PipelineStateSimple (line 20) | class PipelineStateSimple:
method __init__ (line 27) | def __init__(self, stages: int, phase_index: Int32):
method clone (line 34) | def clone(self) -> "PipelineStateSimple":
method stages (line 38) | def stages(self) -> int:
method index (line 43) | def index(self) -> Int32:
method phase (line 52) | def phase(self) -> Int32:
method advance (line 63) | def advance(self):
method __extract_mlir_values__ (line 84) | def __extract_mlir_values__(self):
method __new_from_mlir_values__ (line 88) | def __new_from_mlir_values__(self, values):
function make_pipeline_state (line 92) | def make_pipeline_state(type: PipelineUserType, stages: int):
class NamedBarrier (line 106) | class NamedBarrier(NamedBarrierOg):
method create (line 108) | def create(*args, **kwargs):
method arrive_w_index (line 115) | def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
method arrive_and_wait_w_index (line 128) | def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) ...
class PipelineAsync (line 138) | class PipelineAsync(PipelineAsyncOg):
method create (line 140) | def create(*args, **kwargs):
method producer_acquire_w_index_phase (line 148) | def producer_acquire_w_index_phase(
method producer_commit_w_index (line 165) | def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
method consumer_wait_w_index_phase (line 169) | def consumer_wait_w_index_phase(
method consumer_release_w_index (line 186) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
class PipelineTmaAsync (line 191) | class PipelineTmaAsync(PipelineTmaAsyncOg):
method create (line 197) | def create(*args, **kwargs):
method producer_acquire (line 204) | def producer_acquire(
method producer_acquire_w_index_phase (line 229) | def producer_acquire_w_index_phase(
method consumer_wait_w_index_phase (line 247) | def consumer_wait_w_index_phase(
method consumer_release_w_index (line 264) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
class PipelineTmaUmma (line 275) | class PipelineTmaUmma(PipelineTmaUmmaOg):
method create (line 281) | def create(*args, **kwargs):
method producer_acquire (line 289) | def producer_acquire(
method producer_acquire_w_index_phase (line 328) | def producer_acquire_w_index_phase(
method consumer_wait_w_index_phase (line 354) | def consumer_wait_w_index_phase(
method consumer_release_w_index (line 371) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
class PipelineUmmaAsync (line 379) | class PipelineUmmaAsync(PipelineUmmaAsyncOg):
method create (line 381) | def create(*args, **kwargs):
method producer_acquire_w_index_phase (line 388) | def producer_acquire_w_index_phase(
method producer_commit_w_index (line 405) | def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
method consumer_wait_w_index_phase (line 412) | def consumer_wait_w_index_phase(
method consumer_release_w_index (line 429) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
class PipelineAsyncUmma (line 434) | class PipelineAsyncUmma(PipelineAsyncUmmaOg):
method create (line 436) | def create(*args, **kwargs):
method producer_acquire_w_index_phase (line 443) | def producer_acquire_w_index_phase(
method producer_commit_w_index (line 460) | def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
method consumer_wait_w_index_phase (line 464) | def consumer_wait_w_index_phase(
method consumer_release_w_index (line 481) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
FILE: flash_attn/cute/seqlen_info.py
class SeqlenInfo (line 18) | class SeqlenInfo:
method create (line 25) | def create(
method offset_batch (line 47) | def offset_batch(
class SeqlenInfoQK (line 67) | class SeqlenInfoQK:
method create (line 80) | def create(
method offset_batch_Q (line 132) | def offset_batch_Q(
method offset_batch_K (line 171) | def offset_batch_K(
class SeqlenInfoQKNewK (line 204) | class SeqlenInfoQKNewK:
method create (line 227) | def create(
FILE: flash_attn/cute/sm90_config_search.py
function _divisors (line 20) | def _divisors(n):
function _acc_regs (line 24) | def _acc_regs(M, N, num_wg):
function _check_mma (line 29) | def _check_mma(M, N, num_wg, atom_layout_m, swap_AB):
function _mma_traffic (line 44) | def _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False):
function _check_bwd_config (line 61) | def _check_bwd_config(
function find_feasible_bwd_configs (line 174) | def find_feasible_bwd_configs(
function print_bwd_configs (line 224) | def print_bwd_configs(configs, max_results=20):
function _check_fwd_config (line 260) | def _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg):
function find_feasible_fwd_configs (line 315) | def find_feasible_fwd_configs(
function print_fwd_configs (line 336) | def print_fwd_configs(configs, max_results=20):
FILE: flash_attn/cute/softmax.py
class Softmax (line 19) | class Softmax(ParamsBase):
method create (line 28) | def create(
method reset (line 38) | def reset(self) -> None:
method _compute_row_max (line 42) | def _compute_row_max(
method _compute_row_sum (line 47) | def _compute_row_sum(
method online_softmax (line 53) | def online_softmax(
method finalize (line 119) | def finalize(
method rescale_O (line 156) | def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:
class SoftmaxSm100 (line 170) | class SoftmaxSm100(Softmax):
method create (line 174) | def create(
method update_row_max (line 194) | def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> ...
method update_row_sum (line 213) | def update_row_sum(
method scale_subtract_rowmax (line 223) | def scale_subtract_rowmax(
method apply_exp2_convert (line 238) | def apply_exp2_convert(
method scale_apply_exp2_convert (line 282) | def scale_apply_exp2_convert(
function floor_if_packed (line 333) | def floor_if_packed(
function apply_score_mod_inner (line 344) | def apply_score_mod_inner(
function apply_score_mod_bwd_inner (line 474) | def apply_score_mod_bwd_inner(
FILE: flash_attn/cute/testing.py
class IndexFirstAxis (line 13) | class IndexFirstAxis(torch.autograd.Function):
method forward (line 15) | def forward(ctx, input, indices):
method backward (line 27) | def backward(ctx, grad_output):
class IndexPutFirstAxis (line 44) | class IndexPutFirstAxis(torch.autograd.Function):
method forward (line 46) | def forward(ctx, values, indices, first_axis_dim):
method backward (line 57) | def backward(ctx, grad_output):
function unpad_input (line 66) | def unpad_input(hidden_states, attention_mask, unused_mask=None):
function pad_input (line 89) | def pad_input(hidden_states, indices, batch, seqlen):
function generate_random_padding_mask (line 94) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r...
function generate_qkv (line 124) | def generate_qkv(
function construct_local_mask (line 249) | def construct_local_mask(
function construct_chunk_mask (line 291) | def construct_chunk_mask(
function attention_ref (line 323) | def attention_ref(
function maybe_fake_tensor_mode (line 437) | def maybe_fake_tensor_mode(fake: bool = True):
function is_fake_mode (line 455) | def is_fake_mode() -> bool:
FILE: flash_attn/cute/tile_scheduler.py
class WorkTileInfo (line 23) | class WorkTileInfo(cutlass.utils.WorkTileInfo):
method __new_from_mlir_values__ (line 27) | def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTil...
class TileSchedulerArguments (line 35) | class TileSchedulerArguments(ParamsBase):
class SingleTileScheduler (line 56) | class SingleTileScheduler:
class Params (line 58) | class Params(ParamsBase):
method create (line 68) | def create(
method __init__ (line 81) | def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None,...
method to_underlying_arguments (line 89) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,...
method create (line 93) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileSchedul...
method get_grid_shape (line 105) | def get_grid_shape(
method get_current_work (line 119) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
method initial_work_tile_info (line 130) | def initial_work_tile_info(self, *, loc=None, ip=None):
method prefetch_next_work (line 133) | def prefetch_next_work(self, *, loc=None, ip=None):
method advance_to_next_work (line 136) | def advance_to_next_work(self, *, loc=None, ip=None):
method __extract_mlir_values__ (line 139) | def __extract_mlir_values__(self):
method __new_from_mlir_values__ (line 147) | def __new_from_mlir_values__(self, values):
class StaticPersistentTileScheduler (line 155) | class StaticPersistentTileScheduler:
class Params (line 157) | class Params(ParamsBase):
method create (line 164) | def create(
method __init__ (line 176) | def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=No...
method to_underlying_arguments (line 183) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,...
method create (line 187) | def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentT...
method get_grid_shape (line 196) | def get_grid_shape(
method get_current_work (line 210) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
method initial_work_tile_info (line 220) | def initial_work_tile_info(self, *, loc=None, ip=None):
method prefetch_next_work (line 223) | def prefetch_next_work(self, *, loc=None, ip=None):
method advance_to_next_work (line 226) | def advance_to_next_work(self, *, loc=None, ip=None):
method __extract_mlir_values__ (line 232) | def __extract_mlir_values__(self):
method __new_from_mlir_values__ (line 240) | def __new_from_mlir_values__(self, values):
class SingleTileLPTScheduler (line 251) | class SingleTileLPTScheduler:
class Params (line 253) | class Params(ParamsBase):
method create (line 268) | def create(
method __init__ (line 303) | def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, ...
method to_underlying_arguments (line 311) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,...
method create (line 316) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTSche...
method get_grid_shape (line 322) | def get_grid_shape(
method get_current_work (line 331) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
method initial_work_tile_info (line 351) | def initial_work_tile_info(self, *, loc=None, ip=None):
method prefetch_next_work (line 354) | def prefetch_next_work(self, *, loc=None, ip=None):
method advance_to_next_work (line 357) | def advance_to_next_work(self, *, loc=None, ip=None):
method __extract_mlir_values__ (line 361) | def __extract_mlir_values__(self):
method __new_from_mlir_values__ (line 369) | def __new_from_mlir_values__(self, values):
class SingleTileLPTBwdScheduler (line 377) | class SingleTileLPTBwdScheduler:
class Params (line 379) | class Params(ParamsBase):
method create (line 393) | def create(
method __init__ (line 426) | def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=No...
method to_underlying_arguments (line 433) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,...
method create (line 438) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdS...
method get_grid_shape (line 444) | def get_grid_shape(
method get_current_work (line 453) | def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.Work...
method initial_work_tile_info (line 475) | def initial_work_tile_info(self, *, loc=None, ip=None):
method prefetch_next_work (line 478) | def prefetch_next_work(self, *, loc=None, ip=None):
method advance_to_next_work (line 481) | def advance_to_next_work(self, *, loc=None, ip=None):
method __extract_mlir_values__ (line 485) | def __extract_mlir_values__(self):
method __new_from_mlir_values__ (line 493) | def __new_from_mlir_values__(self, values):
class SingleTileVarlenScheduler (line 501) | class SingleTileVarlenScheduler:
class Params (line 503) | class Params(ParamsBase):
method create (line 520) | def create(
method __init__ (line 547) | def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, ...
method to_underlying_arguments (line 556) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,...
method create (line 560) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenS...
method get_grid_shape (line 566) | def get_grid_shape(
method _get_num_m_blocks (line 581) | def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
method get_current_work (line 604) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
method initial_work_tile_info (line 701) | def initial_work_tile_info(self, *, loc=None, ip=None):
method prefetch_next_work (line 704) | def prefetch_next_work(self, *, loc=None, ip=None):
method advance_to_next_work (line 707) | def advance_to_next_work(self, *, loc=None, ip=None):
method __extract_mlir_values__ (line 711) | def __extract_mlir_values__(self):
method __new_from_mlir_values__ (line 719) | def __new_from_mlir_values__(self, values):
FILE: flash_attn/cute/utils.py
function _compute_base_hash (line 59) | def _compute_base_hash(func: Callable) -> str:
function hash_callable (line 78) | def hash_callable(
function create_softcap_scoremod (line 116) | def create_softcap_scoremod(softcap_val):
function compute_softmax_scale_log2 (line 130) | def compute_softmax_scale_log2(softmax_scale, score_mod):
function compute_fastdiv_mods (line 145) | def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors...
function convert_from_dlpack (line 161) | def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) ->...
function convert_from_dlpack_leading_static (line 171) | def convert_from_dlpack_leading_static(
function make_tiled_copy_A (line 183) | def make_tiled_copy_A(
function make_tiled_copy_B (line 192) | def make_tiled_copy_B(
function mma_make_fragment_A (line 201) | def mma_make_fragment_A(
function mma_make_fragment_B (line 210) | def mma_make_fragment_B(
function get_smem_store_atom (line 219) | def get_smem_store_atom(
function warp_reduce (line 236) | def warp_reduce(
function fmax (line 254) | def fmax(
function fmax_reduce (line 286) | def fmax_reduce(
function fadd_reduce (line 337) | def fadd_reduce(
function atomic_add_fp32 (line 378) | def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=N...
function elem_pointer (line 400) | def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None...
function predicate_k (line 405) | def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
function canonical_warp_group_idx (line 420) | def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
function shuffle_sync (line 448) | def shuffle_sync(
function shl_u32 (line 468) | def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=...
function shr_u32 (line 502) | def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=...
function warp_prefix_sum (line 526) | def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = ...
function cvt_f16x2_f32 (line 541) | def cvt_f16x2_f32(
function cvt_f16 (line 559) | def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...
function cvt_f16 (line 563) | def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor:...
function cvt_f16 (line 567) | def cvt_f16(src: cute.Tensor, dst_or_dtype):
function evaluate_polynomial (line 600) | def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=No...
function evaluate_polynomial_2 (line 610) | def evaluate_polynomial_2(
function add_round_down (line 621) | def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ...
function combine_int_frac_ex2 (line 637) | def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=N...
function ex2_emulation (line 664) | def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None...
function ex2_emulation_2 (line 681) | def ex2_emulation_2(
function e2e_asm2 (line 702) | def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Floa...
function domain_offset_aligned (line 749) | def domain_offset_aligned(
function scalar_to_ssa (line 764) | def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
function ssa_to_scalar (line 771) | def ssa_to_scalar(val):
FILE: flash_attn/flash_attn_interface.py
function maybe_contiguous (line 27) | def maybe_contiguous(x):
function _get_block_size_n (line 31) | def _get_block_size_n(device, head_dim, is_dropout, is_causal):
function round_multiple (line 57) | def round_multiple(x, m):
function noop_custom_op_wrapper (line 68) | def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_typ...
function noop_register_fake_wrapper (line 74) | def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
function _flash_attn_forward (line 85) | def _flash_attn_forward(
function _flash_attn_forward_fake (line 118) | def _flash_attn_forward_fake(
function _flash_attn_varlen_forward (line 154) | def _flash_attn_varlen_forward(
function _flash_attn_varlen_forward_fake (line 205) | def _flash_attn_varlen_forward_fake(
function _flash_attn_backward (line 250) | def _flash_attn_backward(
function _flash_attn_backward_fake (line 302) | def _flash_attn_backward_fake(
function _flash_attn_varlen_backward (line 345) | def _flash_attn_varlen_backward(
function _flash_attn_varlen_backward_fake (line 409) | def _flash_attn_varlen_backward_fake(
class FlashAttnQKVPackedFunc (line 458) | class FlashAttnQKVPackedFunc(torch.autograd.Function):
method forward (line 460) | def forward(
method backward (line 508) | def backward(ctx, dout, *args):
class FlashAttnVarlenQKVPackedFunc (line 540) | class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
method forward (line 542) | def forward(
method backward (line 598) | def backward(ctx, dout, *args):
class FlashAttnKVPackedFunc (line 634) | class FlashAttnKVPackedFunc(torch.autograd.Function):
method forward (line 636) | def forward(
method backward (line 687) | def backward(ctx, dout, *args):
class FlashAttnVarlenKVPackedFunc (line 721) | class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
method forward (line 723) | def forward(
method backward (line 787) | def backward(ctx, dout, *args):
class FlashAttnFunc (line 825) | class FlashAttnFunc(torch.autograd.Function):
method forward (line 827) | def forward(
method backward (line 878) | def backward(ctx, dout, *args):
class FlashAttnVarlenFunc (line 911) | class FlashAttnVarlenFunc(torch.autograd.Function):
method forward (line 913) | def forward(
method backward (line 979) | def backward(ctx, dout, *args):
function flash_attn_qkvpacked_func (line 1016) | def flash_attn_qkvpacked_func(
function flash_attn_kvpacked_func (line 1075) | def flash_attn_kvpacked_func(
function flash_attn_func (line 1153) | def flash_attn_func(
function flash_attn_varlen_qkvpacked_func (line 1230) | def flash_attn_varlen_qkvpacked_func(
function flash_attn_varlen_kvpacked_func (line 1296) | def flash_attn_varlen_kvpacked_func(
function flash_attn_varlen_func (line 1388) | def flash_attn_varlen_func(
function flash_attn_with_kvcache (line 1482) | def flash_attn_with_kvcache(
FILE: flash_attn/flash_attn_triton.py
function _fwd_kernel (line 66) | def _fwd_kernel(
function _bwd_preprocess_do_o_dot (line 288) | def _bwd_preprocess_do_o_dot(
function _bwd_store_dk_dv (line 333) | def _bwd_store_dk_dv(
function _bwd_kernel_one_col_block (line 365) | def _bwd_kernel_one_col_block(
function init_to_zero (line 633) | def init_to_zero(name):
function _bwd_kernel (line 668) | def _bwd_kernel(
function _flash_attn_forward (line 812) | def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=...
function _flash_attn_backward (line 894) | def _flash_attn_backward(
class FlashAttnQKVPackedFunc (line 1013) | class FlashAttnQKVPackedFunc(torch.autograd.Function):
method forward (line 1015) | def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
method backward (line 1038) | def backward(ctx, do):
class FlashAttnKVPackedFunc (line 1065) | class FlashAttnKVPackedFunc(torch.autograd.Function):
method forward (line 1067) | def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
method backward (line 1085) | def backward(ctx, do):
class FlashAttnFunc (line 1114) | class FlashAttnFunc(torch.autograd.Function):
method forward (line 1116) | def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
method backward (line 1134) | def backward(ctx, do):
FILE: flash_attn/flash_attn_triton_og.py
function _fwd_kernel (line 19) | def _fwd_kernel(
function _bwd_preprocess (line 121) | def _bwd_preprocess(
function _bwd_kernel (line 145) | def _bwd_kernel(
class _attention (line 248) | class _attention(torch.autograd.Function):
method forward (line 250) | def forward(ctx, q, k, v, sm_scale):
method backward (line 307) | def backward(ctx, do):
FILE: flash_attn/flash_blocksparse_attention.py
class FlashBlocksparseAttention (line 15) | class FlashBlocksparseAttention(nn.Module):
method __init__ (line 26) | def __init__(
method forward (line 48) | def forward(
class FlashBlocksparseMHA (line 154) | class FlashBlocksparseMHA(nn.Module):
method __init__ (line 155) | def __init__(
method forward (line 189) | def forward(
FILE: flash_attn/flash_blocksparse_attn_interface.py
function convert_blockmask (line 7) | def convert_blockmask(blockmask, causal):
function _flash_blocksparse_attn_forward (line 42) | def _flash_blocksparse_attn_forward(
function _flash_blocksparse_attn_backward (line 54) | def _flash_blocksparse_attn_backward(
class FlashBlocksparseAttnFun (line 86) | class FlashBlocksparseAttnFun(torch.autograd.Function):
method forward (line 88) | def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax...
method backward (line 111) | def backward(ctx, dout):
class FlashBlocksparseAttnFunWithS (line 137) | class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
method forward (line 139) | def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax...
method backward (line 162) | def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
function flash_blocksparse_attn_func (line 185) | def flash_blocksparse_attn_func(
FILE: flash_attn/layers/patch_embed.py
class PatchEmbed (line 17) | class PatchEmbed(nn.Module):
method __init__ (line 20) | def __init__(
method forward (line 46) | def forward(self, x):
FILE: flash_attn/layers/rotary.py
function rotate_half (line 14) | def rotate_half(x, interleaved=False):
function apply_rotary_emb_torch (line 23) | def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
class ApplyRotaryEmb (line 38) | class ApplyRotaryEmb(torch.autograd.Function):
method forward (line 40) | def forward(
method backward (line 73) | def backward(ctx, do):
function apply_rotary_emb (line 93) | def apply_rotary_emb(
function _apply_rotary_emb_qkv (line 130) | def _apply_rotary_emb_qkv(
class ApplyRotaryEmbQKV_ (line 194) | class ApplyRotaryEmbQKV_(torch.autograd.Function):
method forward (line 196) | def forward(
method backward (line 223) | def backward(ctx, dqkv):
function apply_rotary_emb_qkv_ (line 236) | def apply_rotary_emb_qkv_(
class ApplyRotaryEmbKV_ (line 267) | class ApplyRotaryEmbKV_(torch.autograd.Function):
method forward (line 270) | def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Unio...
method backward (line 287) | def backward(ctx, dkv):
function apply_rotary_emb_kv_ (line 308) | def apply_rotary_emb_kv_(
class RotaryEmbedding (line 331) | class RotaryEmbedding(torch.nn.Module):
method __init__ (line 349) | def __init__(
method _compute_inv_freq (line 382) | def _compute_inv_freq(self, device=None):
method _update_cos_sin_cache (line 388) | def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
method forward (line 429) | def forward(
FILE: flash_attn/losses/cross_entropy.py
class CrossEntropyLoss (line 9) | class CrossEntropyLoss(nn.Module):
method __init__ (line 10) | def __init__(
method forward (line 47) | def forward(self, input, target, precomputed_lse=None):
FILE: flash_attn/models/baichuan.py
function remap_state_dict_hf_baichuan (line 17) | def remap_state_dict_hf_baichuan(state_dict, config):
function baichuan_config_to_gpt2_config (line 115) | def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) ->...
FILE: flash_attn/models/bert.py
function create_mixer_cls (line 57) | def create_mixer_cls(config, cross_attn=False, return_residual=False):
function create_mlp_cls (line 80) | def create_mlp_cls(config, layer_idx=None, return_residual=False):
function create_block (line 116) | def create_block(config, layer_idx=None):
function _init_weights (line 141) | def _init_weights(module, initializer_range=0.02):
class BertEncoder (line 152) | class BertEncoder(nn.Module):
method __init__ (line 153) | def __init__(self, config: BertConfig):
method forward (line 160) | def forward(self, hidden_states, key_padding_mask=None, subset_mask=No...
class BertPooler (line 215) | class BertPooler(nn.Module):
method __init__ (line 216) | def __init__(self, config):
method forward (line 225) | def forward(self, hidden_states, pool=True):
class BertPredictionHeadTransform (line 234) | class BertPredictionHeadTransform(nn.Module):
method __init__ (line 235) | def __init__(self, config):
method forward (line 253) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class BertLMPredictionHead (line 265) | class BertLMPredictionHead(nn.Module):
method __init__ (line 266) | def __init__(self, config):
method forward (line 279) | def forward(self, hidden_states):
class BertPreTrainingHeads (line 285) | class BertPreTrainingHeads(nn.Module):
method __init__ (line 286) | def __init__(self, config):
method forward (line 291) | def forward(self, sequence_output, pooled_output):
class BertPreTrainedModel (line 297) | class BertPreTrainedModel(nn.Module):
method __init__ (line 302) | def __init__(self, config, *inputs, **kwargs):
method from_pretrained (line 315) | def from_pretrained(cls, model_name, config, *inputs, **kwargs):
class BertModel (line 340) | class BertModel(BertPreTrainedModel):
method __init__ (line 341) | def __init__(self, config: BertConfig, add_pooling_layer=True):
method forward (line 367) | def forward(
class BertForPreTraining (line 427) | class BertForPreTraining(BertPreTrainedModel):
method __init__ (line 428) | def __init__(self, config: BertConfig):
method tie_weights (line 456) | def tie_weights(self):
method forward (line 459) | def forward(
function remap_state_dict (line 524) | def remap_state_dict(state_dict, config: PretrainedConfig):
function inv_remap_state_dict (line 637) | def inv_remap_state_dict(state_dict, config: PretrainedConfig):
FILE: flash_attn/models/bigcode.py
function remap_state_dict_hf_bigcode (line 10) | def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
function inv_remap_state_dict_hf_bigcode (line 112) | def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
function bigcode_config_to_gpt2_config (line 206) | def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> G...
FILE: flash_attn/models/btlm.py
function remap_state_dict_hf_btlm (line 17) | def remap_state_dict_hf_btlm(state_dict, config):
function btlm_config_to_gpt2_config (line 78) | def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Con...
FILE: flash_attn/models/falcon.py
function remap_state_dict_hf_falcon (line 13) | def remap_state_dict_hf_falcon(state_dict, config):
function falcon_config_to_gpt2_config (line 106) | def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Con...
FILE: flash_attn/models/gpt.py
function create_mixer_cls (line 62) | def create_mixer_cls(config, layer_idx=None, process_group=None, device=...
function create_mlp_cls (line 123) | def create_mlp_cls(config, layer_idx=None, process_group=None, device=No...
function create_block (line 262) | def create_block(config, layer_idx=None, process_group=None, device=None...
class GPTPreTrainedModel (line 311) | class GPTPreTrainedModel(nn.Module):
method __init__ (line 316) | def __init__(self, config, *inputs, **kwargs):
method from_pretrained (line 329) | def from_pretrained(
function _init_weights (line 380) | def _init_weights(
class GPTModel (line 409) | class GPTModel(GPTPreTrainedModel):
method __init__ (line 410) | def __init__(self, config: GPT2Config, process_group=None, device=None...
method tie_weights (line 504) | def tie_weights(self):
method allocate_inference_cache (line 508) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,...
method forward (line 514) | def forward(self, input_ids, position_ids=None, inference_params=None):
class GPTLMHeadModel (line 577) | class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
method __init__ (line 578) | def __init__(self, config: GPT2Config, process_group=None, device=None...
method tie_weights (line 624) | def tie_weights(self):
method allocate_inference_cache (line 630) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,...
method forward (line 635) | def forward(self, input_ids, position_ids=None, inference_params=None,...
method load_state_dict (line 671) | def load_state_dict(self, state_dict, strict=True):
function shard_state_dict_tp (line 698) | def shard_state_dict_tp(state_dict, config, world_size, rank):
function combine_state_dicts_tp (line 814) | def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], c...
function remap_state_dict_hf_gpt2 (line 930) | def remap_state_dict_hf_gpt2(state_dict, config):
function remap_state_dict_megatron (line 987) | def remap_state_dict_megatron(state_dict, config):
FILE: flash_attn/models/gpt_neox.py
function remap_state_dict_hf_gpt_neox (line 13) | def remap_state_dict_hf_gpt_neox(state_dict, config):
function gpt_neox_config_to_gpt2_config (line 101) | def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GP...
FILE: flash_attn/models/gptj.py
function remap_state_dict_hf_gptj (line 12) | def remap_state_dict_hf_gptj(state_dict, config):
function gptj_config_to_gpt2_config (line 82) | def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
FILE: flash_attn/models/llama.py
function remap_state_dict_meta_llama (line 19) | def remap_state_dict_meta_llama(
function remap_state_dict_hf_llama (line 115) | def remap_state_dict_hf_llama(
function inv_remap_state_dict_hf_llama (line 219) | def inv_remap_state_dict_hf_llama(
function config_from_meta_checkpoint (line 329) | def config_from_meta_checkpoint(
function config_from_hf_checkpoint (line 368) | def config_from_hf_checkpoint(
function config_from_checkpoint (line 374) | def config_from_checkpoint(
function state_dicts_from_checkpoint (line 383) | def state_dicts_from_checkpoint(
function llama_config_to_gpt2_config (line 393) | def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
FILE: flash_attn/models/opt.py
function remap_state_dict_hf_opt (line 12) | def remap_state_dict_hf_opt(state_dict, config):
function opt_config_to_gpt2_config (line 90) | def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
FILE: flash_attn/models/vit.py
function create_mixer_cls (line 28) | def create_mixer_cls(
function create_mlp_cls (line 43) | def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
function create_block (line 52) | def create_block(
class VisionTransformer (line 97) | class VisionTransformer(nn.Module):
method __init__ (line 103) | def __init__(
method init_weights (line 240) | def init_weights(self, mode=""):
method _init_weights (line 247) | def _init_weights(self, m):
method no_weight_decay (line 252) | def no_weight_decay(self):
method _pos_embed (line 255) | def _pos_embed(self, x):
method forward_features (line 270) | def forward_features(self, x, all_tokens=True):
method forward_head (line 317) | def forward_head(self, x, pre_logits: bool = False):
method forward (line 322) | def forward(self, x):
method load_state_dict (line 327) | def load_state_dict(self, state_dict, strict=True):
function init_weights_vit_timm (line 356) | def init_weights_vit_timm(module: nn.Module, name: str = ""):
function vit_base_patch16_224 (line 366) | def vit_base_patch16_224(pretrained=False, **kwargs):
FILE: flash_attn/modules/block.py
class Block (line 21) | class Block(nn.Module):
method __init__ (line 22) | def __init__(
method allocate_inference_cache (line 105) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,...
method forward (line 108) | def forward(
class ParallelBlock (line 259) | class ParallelBlock(nn.Module):
method __init__ (line 264) | def __init__(
method allocate_inference_cache (line 332) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,...
method forward (line 335) | def forward(
FILE: flash_attn/modules/embedding.py
class GPT2Embeddings (line 11) | class GPT2Embeddings(nn.Module):
method __init__ (line 12) | def __init__(
method forward (line 47) | def forward(self, input_ids, position_ids=None):
class BertEmbeddings (line 64) | class BertEmbeddings(nn.Module):
method __init__ (line 65) | def __init__(
method forward (line 93) | def forward(self, input_ids, position_ids=None, token_type_ids=None):
class VocabParallelEmbedding (line 114) | class VocabParallelEmbedding(nn.Embedding):
method __init__ (line 115) | def __init__(self, num_embeddings, *args, process_group=None, padding_...
method forward (line 130) | def forward(self, input: Tensor) -> Tensor:
class ColumnParallelEmbedding (line 146) | class ColumnParallelEmbedding(nn.Embedding):
method __init__ (line 147) | def __init__(self, num_embeddings, embedding_dim, *args, process_group...
class ParallelGPT2Embeddings (line 161) | class ParallelGPT2Embeddings(nn.Module):
method __init__ (line 162) | def __init__(
method forward (line 193) | def forward(self, input_ids, position_ids=None, combine_batch_seqlen_d...
FILE: flash_attn/modules/mha.py
function get_alibi_slopes (line 37) | def get_alibi_slopes(nheads):
class FlashSelfAttention (line 53) | class FlashSelfAttention(nn.Module):
method __init__ (line 64) | def __init__(
method forward (line 83) | def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
class FlashCrossAttention (line 133) | class FlashCrossAttention(nn.Module):
method __init__ (line 144) | def __init__(
method forward (line 163) | def forward(
class SelfAttention (line 230) | class SelfAttention(nn.Module):
method __init__ (line 241) | def __init__(self, causal=False, softmax_scale=None, attention_dropout...
method forward (line 247) | def forward(self, qkv, causal=None, key_padding_mask=None):
class CrossAttention (line 282) | class CrossAttention(nn.Module):
method __init__ (line 293) | def __init__(self, causal=False, softmax_scale=None, attention_dropout...
method forward (line 299) | def forward(self, q, kv, causal=None, key_padding_mask=None):
function _update_kv_cache (line 344) | def _update_kv_cache(kv, inference_params, layer_idx):
class MHA (line 373) | class MHA(nn.Module):
method __init__ (line 376) | def __init__(
method allocate_inference_cache (line 483) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
method _update_kv_cache (line 496) | def _update_kv_cache(self, kv, inference_params):
method _apply_rotary_update_kvcache_attention (line 502) | def _apply_rotary_update_kvcache_attention(self, q, kv, inference_para...
method _update_kvcache_attention (line 542) | def _update_kvcache_attention(self, q, kv, inference_params):
method forward (line 573) | def forward(
class ParallelMHA (line 707) | class ParallelMHA(nn.Module):
method __init__ (line 710) | def __init__(
method allocate_inference_cache (line 824) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
method _update_kv_cache (line 837) | def _update_kv_cache(self, kv, inference_params):
method _apply_rotary_update_kvcache_attention (line 842) | def _apply_rotary_update_kvcache_attention(self, q, kv, inference_para...
method _update_kvcache_attention (line 882) | def _update_kvcache_attention(self, q, kv, inference_params):
method forward (line 910) | def forward(self, x, seqlen=None, inference_params=None, **kwargs):
FILE: flash_attn/modules/mlp.py
class Mlp (line 25) | class Mlp(nn.Module):
method __init__ (line 26) | def __init__(
method forward (line 47) | def forward(self, x):
class ParallelMLP (line 54) | class ParallelMLP(nn.Module):
method __init__ (line 55) | def __init__(
method forward (line 92) | def forward(self, x):
class GatedMlp (line 99) | class GatedMlp(nn.Module):
method __init__ (line 100) | def __init__(
method forward (line 125) | def forward(self, x):
class ParallelGatedMlp (line 139) | class ParallelGatedMlp(nn.Module):
method __init__ (line 142) | def __init__(
method forward (line 183) | def forward(self, x):
FILE: flash_attn/ops/activations.py
function bias_gelu (line 16) | def bias_gelu(y, bias):
function bias_gelu_back (line 25) | def bias_gelu_back(g, y, bias):
class GeLUFunction (line 37) | class GeLUFunction(torch.autograd.Function):
method forward (line 40) | def forward(ctx, input, bias):
method backward (line 45) | def backward(ctx, grad_output):
function gelu_fwd (line 57) | def gelu_fwd(x):
function gelu_bwd (line 65) | def gelu_bwd(g, x):
class FastGeLUFunction (line 74) | class FastGeLUFunction(torch.autograd.Function):
method forward (line 77) | def forward(ctx, input):
method backward (line 82) | def backward(ctx, grad_output):
function relu_bwd (line 92) | def relu_bwd(g, x):
function sqrelu_fwd (line 97) | def sqrelu_fwd(x):
function sqrelu_bwd (line 103) | def sqrelu_bwd(g, x):
class SwiGLUFunction (line 123) | class SwiGLUFunction(torch.autograd.Function):
method forward (line 126) | def forward(ctx, x, y):
method backward (line 131) | def backward(ctx, dout):
FILE: flash_attn/ops/fused_dense.py
class FusedDenseFunc (line 27) | class FusedDenseFunc(torch.autograd.Function):
method forward (line 30) | def forward(
method backward (line 71) | def backward(ctx, grad_output, *args):
function fused_dense_func (line 118) | def fused_dense_func(
class FusedDense (line 139) | class FusedDense(nn.Linear):
method __init__ (line 140) | def __init__(
method forward (line 152) | def forward(self, x, process_group=None):
class ColumnParallelLinear (line 166) | class ColumnParallelLinear(nn.Linear):
method __init__ (line 167) | def __init__(
method forward (line 193) | def forward(self, x):
class RowParallelLinear (line 206) | class RowParallelLinear(nn.Linear):
method __init__ (line 207) | def __init__(
method forward (line 239) | def forward(self, x):
class FusedMLPFunc (line 249) | class FusedMLPFunc(torch.autograd.Function):
method forward (line 252) | def forward(
method backward (line 349) | def backward(ctx, grad_output, *args):
function fused_mlp_func (line 475) | def fused_mlp_func(
class FusedMLP (line 531) | class FusedMLP(nn.Module):
method __init__ (line 532) | def __init__(
method forward (line 580) | def forward(self, x, process_group=None):
class ParallelFusedMLP (line 613) | class ParallelFusedMLP(nn.Module):
method __init__ (line 614) | def __init__(
method forward (line 664) | def forward(self, x):
FILE: flash_attn/ops/layer_norm.py
function maybe_align (line 9) | def maybe_align(x, alignment_in_bytes=16):
function _dropout_add_layer_norm_forward (line 16) | def _dropout_add_layer_norm_forward(
function _dropout_add_layer_norm_backward (line 55) | def _dropout_add_layer_norm_backward(
function _dropout_add_layer_norm_subset_forward (line 110) | def _dropout_add_layer_norm_subset_forward(
function _dropout_add_layer_norm_subset_backward (line 153) | def _dropout_add_layer_norm_subset_backward(
function _dropout_add_layer_norm_parallel_residual_forward (line 212) | def _dropout_add_layer_norm_parallel_residual_forward(
function _dropout_add_layer_norm_parallel_residual_backward (line 257) | def _dropout_add_layer_norm_parallel_residual_backward(
class DropoutAddLayerNormFn (line 311) | class DropoutAddLayerNormFn(torch.autograd.Function):
method forward (line 313) | def forward(
method backward (line 374) | def backward(ctx, dz, *args):
class DropoutAddLayerNormSubsetFn (line 416) | class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
method forward (line 418) | def forward(
method backward (line 483) | def backward(ctx, dz, *args):
class DropoutAddLayerNormParallelResidualFn (line 531) | class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
method forward (line 533) | def forward(
method backward (line 605) | def backward(ctx, dz0, dz1, *args):
function layer_norm (line 657) | def layer_norm(x, weight, bias, epsilon):
function dropout_add_layer_norm (line 661) | def dropout_add_layer_norm(
function dropout_add_layer_norm_subset (line 693) | def dropout_add_layer_norm_subset(
function dropout_add_layer_norm_parallel_residual (line 731) | def dropout_add_layer_norm_parallel_residual(
class DropoutAddLayerNorm (line 765) | class DropoutAddLayerNorm(torch.nn.Module):
method __init__ (line 766) | def __init__(
method reset_parameters (line 786) | def reset_parameters(self):
method forward (line 790) | def forward(self, x0, residual=None):
FILE: flash_attn/ops/rms_norm.py
function rms_norm (line 14) | def rms_norm(x, weight, epsilon):
function dropout_add_rms_norm (line 20) | def dropout_add_rms_norm(
function dropout_add_rms_norm_subset (line 52) | def dropout_add_rms_norm_subset(
function dropout_add_rms_norm_parallel_residual (line 90) | def dropout_add_rms_norm_parallel_residual(
class RMSNorm (line 124) | class RMSNorm(torch.nn.Module):
method __init__ (line 125) | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
method reset_parameters (line 133) | def reset_parameters(self):
method forward (line 136) | def forward(self, x):
class DropoutAddRMSNorm (line 140) | class DropoutAddRMSNorm(torch.nn.Module):
method __init__ (line 141) | def __init__(
method reset_parameters (line 161) | def reset_parameters(self):
method forward (line 164) | def forward(self, x0, residual=None):
FILE: flash_attn/ops/triton/cross_entropy.py
function cross_entropy_fwd_kernel (line 25) | def cross_entropy_fwd_kernel(
function cross_entropy_bwd_kernel (line 104) | def cross_entropy_bwd_kernel(
class CrossEntropyLoss (line 149) | class CrossEntropyLoss(torch.autograd.Function):
method forward (line 152) | def forward(
method backward (line 258) | def backward(ctx, grad_losses, grad_z_losses):
function cross_entropy_loss (line 292) | def cross_entropy_loss(
FILE: flash_attn/ops/triton/k_activations.py
class Activation (line 19) | class Activation(str, Enum):
function get_triton_activation_kernel (line 27) | def get_triton_activation_kernel(activation: Optional[Activation]):
function get_triton_activation_bwd_kernel (line 41) | def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
function tanh (line 56) | def tanh(x):
function cosh (line 62) | def cosh(x):
function relu (line 72) | def relu(x):
function relu_grad (line 83) | def relu_grad(x):
function squared_relu (line 93) | def squared_relu(x):
function squared_relu_grad (line 104) | def squared_relu_grad(x):
function leaky_relu (line 110) | def leaky_relu(x):
function leaky_relu_grad (line 122) | def leaky_relu_grad(x):
function gelu (line 133) | def gelu(x):
function gelu_grad (line 139) | def gelu_grad(x):
function gelu_approx (line 146) | def gelu_approx(x):
function gelu_approx_grad (line 156) | def gelu_approx_grad(x):
FILE: flash_attn/ops/triton/layer_norm.py
function maybe_contiguous_lastdim (line 23) | def maybe_contiguous_lastdim(x):
function maybe_contiguous (line 27) | def maybe_contiguous(x):
function triton_autotune_configs (line 31) | def triton_autotune_configs():
function layer_norm_ref (line 44) | def layer_norm_ref(
function rms_norm_ref (line 104) | def rms_norm_ref(
function _layer_norm_fwd_1pass_kernel (line 174) | def _layer_norm_fwd_1pass_kernel(
function _layer_norm_fwd (line 290) | def _layer_norm_fwd(
function _layer_norm_fwd_impl (line 355) | def _layer_norm_fwd_impl(
function _layer_norm_bwd_kernel (line 485) | def _layer_norm_bwd_kernel(
function _layer_norm_bwd (line 643) | def _layer_norm_bwd(
function _layer_norm_bwd_impl (line 702) | def _layer_norm_bwd_impl(
class LayerNormFn (line 846) | class LayerNormFn(torch.autograd.Function):
method forward (line 849) | def forward(
method backward (line 951) | def backward(ctx, dy, *args):
function layer_norm_fn (line 1010) | def layer_norm_fn(
function rms_norm_fn (line 1052) | def rms_norm_fn(
class RMSNorm (line 1093) | class RMSNorm(torch.nn.Module):
method __init__ (line 1095) | def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered...
method reset_parameters (line 1109) | def reset_parameters(self):
method forward (line 1115) | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=Fa...
class LayerNormLinearFn (line 1129) | class LayerNormLinearFn(torch.autograd.Function):
method forward (line 1133) | def forward(
method backward (line 1187) | def backward(ctx, dout, *args):
function layer_norm_linear_fn (line 1229) | def layer_norm_linear_fn(
FILE: flash_attn/ops/triton/linear.py
function init_to_zero (line 22) | def init_to_zero(name):
function get_configs_io_bound (line 26) | def get_configs_io_bound():
function kernel_fwd (line 131) | def kernel_fwd(
function triton_linear_act (line 258) | def triton_linear_act(
function kernel_bwd (line 428) | def kernel_bwd(
function triton_dgrad_act (line 529) | def triton_dgrad_act(
FILE: flash_attn/ops/triton/mlp.py
class FusedDenseSqreluDenseFunc (line 13) | class FusedDenseSqreluDenseFunc(torch.autograd.Function):
method forward (line 16) | def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
method backward (line 66) | def backward(ctx, grad_output):
class FusedDenseSqreluDense (line 116) | class FusedDenseSqreluDense(nn.Module):
method __init__ (line 117) | def __init__(
method forward (line 145) | def forward(self, x):
FILE: flash_attn/ops/triton/rotary.py
function rotary_kernel (line 13) | def rotary_kernel(
function apply_rotary (line 102) | def apply_rotary(
FILE: flash_attn/utils/benchmark.py
function benchmark_forward (line 8) | def benchmark_forward(
function benchmark_backward (line 30) | def benchmark_backward(
function benchmark_combined (line 72) | def benchmark_combined(
function benchmark_fwd_bwd (line 117) | def benchmark_fwd_bwd(
function benchmark_all (line 154) | def benchmark_all(
function pytorch_profiler (line 202) | def pytorch_profiler(
function benchmark_memory (line 258) | def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
FILE: flash_attn/utils/distributed.py
function all_gather_raw (line 18) | def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op...
function reduce_scatter_raw (line 30) | def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, asyn...
function all_reduce_raw (line 43) | def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op...
class AllGatherFunc (line 49) | class AllGatherFunc(torch.autograd.Function):
method forward (line 53) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
method backward (line 59) | def backward(ctx, grad_output: Tensor):
class ReduceScatterFunc (line 68) | class ReduceScatterFunc(torch.autograd.Function):
method forward (line 72) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
method backward (line 78) | def backward(ctx, grad_output: Tensor):
class AllReduceFunc (line 87) | class AllReduceFunc(torch.autograd.Function):
method forward (line 91) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
method backward (line 97) | def backward(ctx, grad_output: Tensor):
function sync_shared_params (line 105) | def sync_shared_params(model: torch.nn.Module, process_group: ProcessGro...
function allreduce_sequence_parallel_grad (line 120) | def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_gro...
function get_dim_for_local_rank (line 135) | def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, m...
FILE: flash_attn/utils/generation.py
class InferenceParams (line 24) | class InferenceParams:
method reset (line 35) | def reset(self, max_seqlen, max_batch_size):
function modify_logits_for_top_k_filtering (line 45) | def modify_logits_for_top_k_filtering(logits, top_k):
function modify_logits_for_top_p_filtering (line 53) | def modify_logits_for_top_p_filtering(logits, top_p):
function sample (line 69) | def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
function decode (line 99) | def decode(
function sample_speculative (line 209) | def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_...
function decode_speculative (line 269) | def decode_speculative(
class GenerationMixin (line 566) | class GenerationMixin:
method allocate_inference_cache (line 567) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,...
method generate (line 570) | def generate(
function allocate_inference_cache (line 589) | def allocate_inference_cache(
class DecodingCGCache (line 606) | class DecodingCGCache:
function update_graph_cache (line 618) | def update_graph_cache(
function capture_graph (line 693) | def capture_graph(
FILE: flash_attn/utils/library.py
function triton_op (line 10) | def triton_op(
FILE: flash_attn/utils/pretrained.py
function state_dict_from_pretrained (line 15) | def state_dict_from_pretrained(model_name, device=None, dtype=None):
FILE: flash_attn/utils/testing.py
function generate_random_padding_mask (line 11) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r...
function generate_qkv (line 34) | def generate_qkv(
function construct_local_mask (line 159) | def construct_local_mask(
function construct_chunk_mask (line 195) | def construct_chunk_mask(
function attention_ref (line 228) | def attention_ref(
FILE: flash_attn/utils/torch.py
function custom_amp_decorator (line 5) | def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
FILE: hopper/benchmark_attn.py
function time_fwd (line 41) | def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
function flops (line 62) | def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=...
function convert_to_cudnn_type (line 76) | def convert_to_cudnn_type(torch_type):
function cudnn_spda_setup (line 91) | def cudnn_spda_setup(q, k, v, causal=False, window_size_left=-1):
function cudnn_spda_bwd_setup (line 146) | def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_l...
FILE: hopper/benchmark_flash_attention_fp8.py
function convert_to_cudnn_type (line 34) | def convert_to_cudnn_type(torch_type):
function cudnn_spda_setup (line 52) | def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
function attention_pytorch (line 173) | def attention_pytorch(qkv, dropout_p=0.0, causal=True):
function flops (line 201) | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
function efficiency (line 206) | def efficiency(flop, time):
function time_fwd (line 209) | def time_fwd(func, *args, **kwargs):
FILE: hopper/benchmark_split_kv.py
function round_up_to_power_of_2 (line 10) | def round_up_to_power_of_2(x):
function timeit (line 15) | def timeit(fn, *args, **kwargs):
function main (line 35) | def main():
FILE: hopper/block.h
function namespace (line 7) | namespace flash {
function CUTLASS_DEVICE (line 104) | static
function CUTLASS_DEVICE (line 121) | static
FILE: hopper/copy_sm90_bulk_reduce.hpp
type cute (line 9) | namespace cute
type SM90_BULK_REDUCE_ADD (line 14) | struct SM90_BULK_REDUCE_ADD
method CUTE_HOST_DEVICE (line 16) | CUTE_HOST_DEVICE static void
method CUTE_HOST_DEVICE (line 31) | CUTE_HOST_DEVICE static void
FILE: hopper/epilogue_bwd.hpp
type flash (line 17) | namespace flash {
type CollectiveEpilogueBwd (line 23) | struct CollectiveEpilogueBwd {
type TensorStorage (line 85) | struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
type Arguments (line 105) | struct Arguments {
type Params (line 121) | struct Params {
method Params (line 133) | static Params
method CUTLASS_DEVICE (line 156) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 165) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 273) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 279) | CUTLASS_DEVICE void
type CollectiveEpilogueBwdGQA (line 323) | struct CollectiveEpilogueBwdGQA {
type TensorStorageTMA (line 354) | struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
type TensorStorageSTG (line 357) | struct TensorStorageSTG {
type Arguments (line 366) | struct Arguments {
type Params (line 382) | struct Params {
method Params (line 397) | static Params
method CUTLASS_DEVICE (line 410) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 415) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 520) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 525) | CUTLASS_DEVICE void
FILE: hopper/epilogue_fwd.hpp
type flash (line 19) | namespace flash {
type CollectiveEpilogueFwd (line 25) | struct CollectiveEpilogueFwd {
type TensorStorage (line 106) | struct TensorStorage : cute::aligned_struct<128> {
type Arguments (line 122) | struct Arguments {
type Params (line 138) | struct Params {
method Params (line 160) | static Params
method CUTLASS_DEVICE (line 206) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 214) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 404) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 410) | CUTLASS_DEVICE void
FILE: hopper/flash.h
type Qkv_params (line 12) | struct Qkv_params {
function Qkv_params (line 37) | struct Flash_fwd_params : public Qkv_params {
function Flash_fwd_params (line 172) | struct Flash_bwd_params : public Flash_fwd_params {
FILE: hopper/flash_api.cpp
function PyObject (line 24) | PyObject* PyInit__C(void)
function make_cuda_guard_from_tensor (line 45) | inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor&...
function set_params_fprop (line 50) | void set_params_fprop(Flash_fwd_params &params,
function set_params_dgrad (line 170) | void set_params_dgrad(Flash_bwd_params &params,
function run_mha_fwd_constexpr (line 256) | void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
function run_mha_fwd (line 367) | void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
function run_mha_fwd_combine (line 387) | void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, ...
function get_pagedkv_tma (line 415) | inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
function get_pack_gqa (line 426) | inline bool get_pack_gqa(Flash_fwd_params const& params) {
function get_num_splits (line 442) | inline int get_num_splits(Flash_fwd_params const& params) {
function get_max_headdim (line 473) | inline int get_max_headdim() {
function round_up_headdim (line 492) | inline int round_up_headdim(int head_size) {
function round_up_headdimv (line 511) | inline int round_up_headdimv(int head_size) {
function mha_fwd_get_scheduler_metadata (line 521) | at::Tensor
function mha_fwd (line 672) | std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
function run_mha_bwd (line 1201) | void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
function run_mha_bwd_constexpr (line 1206) | void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {
function run_mha_bwd (line 1246) | void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
function mha_bwd (line 1267) | std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> m...
function mha_combine (line 1570) | std::tuple<at::Tensor, at::Tensor>
function TORCH_LIBRARY (line 1673) | TORCH_LIBRARY(flash_attn_3, m) {
function TORCH_LIBRARY_IMPL (line 1764) | TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
FILE: hopper/flash_api_stable.cpp
function make_device_guard (line 36) | inline tsa::DeviceGuard make_device_guard(const Tensor& t) {
function initVectors (line 42) | void initVectors() {
function initDeviceProperty (line 56) | void initDeviceProperty(int device_index) {
function cudaDeviceProp (line 67) | cudaDeviceProp* get_device_prop() {
function PyObject (line 87) | PyObject* PyInit__C(void)
function set_params_fprop (line 114) | void set_params_fprop(Flash_fwd_params &params,
function set_params_dgrad (line 235) | void set_params_dgrad(Flash_bwd_params &params,
function run_mha_fwd_constexpr (line 321) | void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
function run_mha_fwd (line 432) | void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
function run_mha_fwd_combine (line 452) | void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, ...
function get_pagedkv_tma (line 480) | inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
function get_pack_gqa (line 491) | inline bool get_pack_gqa(Flash_fwd_params const& params) {
function get_num_splits (line 507) | inline int get_num_splits(Flash_fwd_params const& params) {
function get_max_headdim (line 538) | inline int get_max_headdim() {
function round_up_headdim (line 557) | inline int round_up_headdim(int head_size) {
function round_up_headdimv (line 576) | inline int round_up_headdimv(int head_size) {
function Tensor (line 586) | Tensor
function mha_fwd (line 741) | std::tuple<Tensor, Tensor, Tensor, Tensor>
function run_mha_bwd (line 1272) | void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
function run_mha_bwd_constexpr (line 1277) | void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {
function run_mha_bwd (line 1317) | void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
function mha_bwd (line 1338) | std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> mha_bwd(
function mha_combine (line 1648) | std::tuple<Tensor, Tensor>
function boxed_mha_fwd (line 1755) | void boxed_mha_fwd(
function boxed_mha_bwd (line 1804) | void boxed_mha_bwd(
function boxed_mha_combine (line 1841) | void boxed_mha_combine(
function boxed_mha_fwd_get_scheduler_metadata (line 1857) | void boxed_mha_fwd_get_scheduler_metadata(
function STABLE_TORCH_LIBRARY (line 1892) | STABLE_TORCH_LIBRARY(flash_attn_3, m) {
function STABLE_TORCH_LIBRARY_IMPL (line 1983) | STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
FILE: hopper/flash_attn_interface.py
function maybe_contiguous (line 30) | def maybe_contiguous(x):
function round_multiple (line 34) | def round_multiple(x, m):
function round_up_headdim (line 38) | def round_up_headdim(head_size: int) -> int:
function _flash_attn_forward (line 60) | def _flash_attn_forward(
function _flash_attn_forward_fake (line 154) | def _flash_attn_forward_fake(
function _flash_attn_backward (line 259) | def _flash_attn_backward(
function _flash_attn_backward_fake (line 313) | def _flash_attn_backward_fake(
function setup_context (line 410) | def setup_context(ctx, inputs, output):
function _backward (line 422) | def _backward(ctx, dout, *grads):
class FlashAttnQKVPackedFunc (line 453) | class FlashAttnQKVPackedFunc(torch.autograd.Function):
method forward (line 455) | def forward(
method backward (line 514) | def backward(ctx, dout, *args):
class FlashAttnFunc (line 552) | class FlashAttnFunc(torch.autograd.Function):
method forward (line 555) | def forward(
method backward (line 611) | def backward(ctx, dout, *args):
class FlashAttnVarlenFunc (line 642) | class FlashAttnVarlenFunc(torch.autograd.Function):
method forward (line 645) | def forward(
method backward (line 713) | def backward(ctx, dout, *args):
function flash_attn_qkvpacked_func (line 747) | def flash_attn_qkvpacked_func(
function flash_attn_func (line 809) | def flash_attn_func(
function flash_attn_varlen_func (line 890) | def flash_attn_varlen_func(
function flash_attn_combine (line 938) | def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
function flash_attn_with_kvcache (line 942) | def flash_attn_with_kvcache(
function get_scheduler_metadata (line 1106) | def get_scheduler_metadata(
FILE: hopper/flash_bwd_kernel_sm80.h
function namespace (line 16) | namespace flash {
type Params (line 80) | struct Params {
function EpilogueParams (line 82) | EpilogueParams epilogue{}
function TileSchedulerParams (line 84) | TileSchedulerParams scheduler{}
function Params (line 92) | static
FILE: hopper/flash_bwd_kernel_sm90.h
function namespace (line 20) | namespace flash {
type Params (line 100) | struct Params {
function EpilogueParams (line 102) | EpilogueParams epilogue{}
function TileSchedulerParams (line 104) | TileSchedulerParams scheduler{}
function Params (line 112) | static
FILE: hopper/flash_bwd_launch_template.h
function dim3 (line 74) | dim3 grid_m(num_m_block, params.h, params.b);
function typename (line 230) | typename PostprocessKernel::Arguments postprocess_args {
FILE: hopper/flash_bwd_postprocess_kernel.h
function namespace (line 18) | namespace flash {
FILE: hopper/flash_bwd_preprocess_kernel.h
function namespace (line 17) | namespace flash {
FILE: hopper/flash_fwd_combine_kernel.h
function namespace (line 20) | namespace flash {
FILE: hopper/flash_fwd_combine_launch_template.h
function typename (line 28) | typename CombineKernel::Arguments args {
function run_mha_fwd_combine_ (line 56) | void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream,...
FILE: hopper/flash_fwd_kernel_sm80.h
function namespace (line 18) | namespace flash {
type Arguments (line 91) | struct Arguments {
function EpilogueArguments (line 93) | EpilogueArguments epilogue{}
function TileSchedulerArguments (line 95) | TileSchedulerArguments scheduler{}
type Params (line 99) | struct Params {
function EpilogueParams (line 101) | EpilogueParams epilogue{}
function TileSchedulerParams (line 103) | TileSchedulerParams scheduler{}
function Params (line 111) | static
FILE: hopper/flash_fwd_kernel_sm90.h
function namespace (line 23) | namespace flash {
type PipelineStorage (line 105) | struct PipelineStorage
type Arguments (line 122) | struct Arguments {
function EpilogueArguments (line 124) | EpilogueArguments epilogue{}
function TileSchedulerArguments (line 126) | TileSchedulerArguments scheduler{}
type Params (line 130) | struct Params {
function EpilogueParams (line 132) | EpilogueParams epilogue{}
function TileSchedulerParams (line 134) | TileSchedulerParams scheduler{}
function dim3 (line 167) | static dim3
function dim3 (line 172) | static dim3
function CUTLASS_DEVICE (line 177) | CUTLASS_DEVICE
FILE: hopper/flash_fwd_launch_template.h
function typename (line 93) | typename CollectiveMainloop::Arguments mainloop_args {
FILE: hopper/generate_kernels.py
class Kernel (line 84) | class Kernel:
method template (line 96) | def template(self) -> str:
method filename (line 127) | def filename(self) -> str:
function get_all_kernels (line 131) | def get_all_kernels() -> List[Kernel]:
function batch_hdim (line 148) | def batch_hdim(kernels_all) -> List[KERNEL_BATCH]:
function batch_softcap (line 166) | def batch_softcap(kernels_all) -> List[KERNEL_BATCH]:
function write_kernel (line 187) | def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
function main (line 195) | def main(output_dir: Optional[str]) -> None:
FILE: hopper/heuristics.h
function should_pack_gqa (line 9) | inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_k...
function num_splits_heuristic (line 25) | inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_...
FILE: hopper/mainloop_bwd_sm80.hpp
type flash (line 20) | namespace flash {
type CollectiveMainloopBwdSm80 (line 29) | struct CollectiveMainloopBwdSm80 {
type TensorStorageSharedQV (line 252) | struct TensorStorageSharedQV : cute::aligned_struct<128> {
type TensorStorageSeparateQV (line 265) | struct TensorStorageSeparateQV : cute::aligned_struct<128> {
type Arguments (line 279) | struct Arguments {
type Params (line 312) | struct Params {
method Params (line 346) | static Params
method CUTLASS_DEVICE (line 377) | CUTLASS_DEVICE bool
FILE: hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp
type flash (line 26) | namespace flash {
type CollectiveMainloopBwdSm90 (line 35) | struct CollectiveMainloopBwdSm90 {
type TensorStorage (line 280) | struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentP...
type Arguments (line 293) | struct Arguments {
type Params (line 326) | struct Params {
method Params (line 355) | static Params
method CUTLASS_DEVICE (line 413) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 422) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 561) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 579) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 594) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 669) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 683) | CUTLASS_DEVICE bool
FILE: hopper/mainloop_fwd_sm80.hpp
type flash (line 22) | namespace flash {
type CollectiveMainloopFwdSm80 (line 29) | struct CollectiveMainloopFwdSm80 {
type TensorStorageSharedQV (line 159) | struct TensorStorageSharedQV : cute::aligned_struct<128> {
type TensorStorageSeparateQV (line 167) | struct TensorStorageSeparateQV : cute::aligned_struct<128> {
type Arguments (line 176) | struct Arguments {
type Params (line 219) | struct Params {
method Params (line 264) | static Params
method CUTLASS_DEVICE (line 308) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 660) | CUTLASS_DEVICE bool
FILE: hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
type flash (line 27) | namespace flash {
type CollectiveMainloopFwdSm90 (line 34) | struct CollectiveMainloopFwdSm90 {
type TensorStorageWithoutPNoTranspose (line 308) | struct TensorStorageWithoutPNoTranspose : cute::aligned_struct<cute:...
type TensorStorageWithPNoTranspose (line 315) | struct TensorStorageWithPNoTranspose : cute::aligned_struct<cute::ma...
type TensorStorageWithPScaleNoTranspose (line 322) | struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct<cut...
type TensorStorageTransposeV (line 340) | struct TensorStorageTransposeV : cute::aligned_struct<cute::max(Smem...
type Arguments (line 359) | struct Arguments {
type Params (line 402) | struct Params {
method Params (line 458) | static Params
method CUTLASS_DEVICE (line 571) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 590) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 894) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 914) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 921) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 933) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 954) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 1355) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 1446) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 1550) | CUTLASS_DEVICE bool
FILE: hopper/mask.h
function thread_mma (line 52) | auto thread_mma = TiledMma{}
function thread0_mma (line 53) | auto thread0_mma = TiledMma{}
FILE: hopper/named_barrier.hpp
type flash (line 9) | namespace flash {
function CUTLASS_DEVICE (line 16) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 24) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 31) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 39) | CUTLASS_DEVICE
type FwdNamedBarriers (line 50) | enum class FwdNamedBarriers {
type BwdNamedBarriers (line 61) | enum class BwdNamedBarriers {
FILE: hopper/pack_gqa.h
function load_Q (line 78) | static void
FILE: hopper/padding.py
function unpad_input (line 8) | def unpad_input(hidden_states, attention_mask, unused_mask=None):
function pad_input (line 40) | def pad_input(hidden_states, indices, batch, seqlen):
FILE: hopper/paged_kv.h
function namespace (line 13) | namespace flash {
function CUTLASS_DEVICE (line 187) | CUTLASS_DEVICE
function CUTLASS_DEVICE (line 198) | CUTLASS_DEVICE
FILE: hopper/rotary.h
function namespace (line 11) | namespace flash {
function Tensor (line 451) | Tensor gK_cur_copy = [&] {
FILE: hopper/setup.py
function create_build_config_file (line 89) | def create_build_config_file():
function _write_ninja_file (line 125) | def _write_ninja_file(path,
function get_platform (line 310) | def get_platform():
function get_cuda_bare_metal_version (line 325) | def get_cuda_bare_metal_version(cuda_dir):
function check_if_cuda_home_none (line 334) | def check_if_cuda_home_none(global_option: str) -> None:
function check_env_flag (line 347) | def check_env_flag(name: str, default: str = "") -> bool:
function is_offline_build (line 352) | def is_offline_build() -> bool:
function get_flashattn_cache_path (line 369) | def get_flashattn_cache_path():
function open_url (line 378) | def open_url(url):
function download_and_copy (line 388) | def download_and_copy(name, src_func, dst_path, version, url_func):
function nvcc_threads_args (line 415) | def nvcc_threads_args():
function get_package_version (line 638) | def get_package_version():
function get_wheel_url (line 649) | def get_wheel_url():
class CachedWheelsCommand (line 672) | class CachedWheelsCommand(_bdist_wheel):
method run (line 680) | def run(self):
FILE: hopper/sm90_pipeline_no_cluster.hpp
class PipelineTmaAsyncNoCluster (line 22) | class PipelineTmaAsyncNoCluster: public Base {
method CUTLASS_DEVICE (line 33) | static
method if (line 62) | if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
FILE: hopper/softmax.h
function namespace (line 15) | namespace flash {
function CUTLASS_DEVICE (line 99) | CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_...
function online_softmax (line 127) | void online_softmax(Tensor0 &acc_s) {
FILE: hopper/test_attn_kvcache.py
function construct_local_mask (line 10) | def construct_local_mask(
function attention_ref (line 45) | def attention_ref(
function test_flash_attn_kvcache_nosplit (line 155) | def test_flash_attn_kvcache_nosplit(nheads_kv, gqa_ratio, num_requests, ...
function test_flash_attn_kvcache_nosplit_fp8 (line 217) | def test_flash_attn_kvcache_nosplit_fp8(nheads_kv, gqa_ratio, num_reques...
function test_flash_attn_kvcache_output (line 292) | def test_flash_attn_kvcache_output(nheads_kv, gqa_ratio, num_requests, q...
function test_flash_attn_kvcache_output_fp8 (line 399) | def test_flash_attn_kvcache_output_fp8(nheads_kv, gqa_ratio, num_request...
FILE: hopper/test_flash_attn.py
function should_test_backward (line 58) | def should_test_backward(args, kwargs):
function should_run_schema_check (line 80) | def should_run_schema_check(args, kwargs):
function should_run_fake_check (line 87) | def should_run_fake_check(args, kwargs):
function run_opcheck (line 93) | def run_opcheck(fn):
function test_flash_attn_output (line 167) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 404) | def test_flash_attn_varlen_output(
function test_flash_attn_kvcache (line 715) | def test_flash_attn_kvcache(
function _generate_block_kvcache (line 1054) | def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d...
function test_flash_attn_cluster (line 1090) | def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
function test_flash_attn_race_condition (line 1133) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
function attention_combine_ref (line 1167) | def attention_combine_ref(out_partial, lse_partial):
function test_flash_attn_combine (line 1190) | def test_flash_attn_combine(num_splits, seqlen, d, dtype):
function test_flash3_bw_compatibility (line 1225) | def test_flash3_bw_compatibility() -> None:
FILE: hopper/test_flash_attn_bwd_determinism.py
function test_flash_attn_output (line 110) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 391) | def test_flash_attn_varlen_output(
FILE: hopper/test_flash_attn_triton_amd.py
function test_flash_attn_output (line 105) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 334) | def test_flash_attn_varlen_output(
function test_flash_attn_kvcache (line 628) | def test_flash_attn_kvcache(
function _generate_block_kvcache (line 962) | def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d...
function test_flash_attn_cluster (line 998) | def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
function test_flash_attn_race_condition (line 1042) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
function attention_combine_ref (line 1076) | def attention_combine_ref(out_partial, lse_partial):
function test_flash_attn_combine (line 1099) | def test_flash_attn_combine(num_splits, seqlen, d, dtype):
function test_flash3_bw_compatibility (line 1135) | def test_flash3_bw_compatibility() -> None:
FILE: hopper/test_kvcache.py
function benchmark_fa_kv_old (line 20) | def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs):
function benchmark_fa_kv (line 34) | def benchmark_fa_kv(fn, repeats=10, *args, **kwargs):
function main (line 47) | def main():
FILE: hopper/test_torch_compile_and_export.py
class EfficienctMultiHeadAttention (line 6) | class EfficienctMultiHeadAttention(nn.Module):
method __init__ (line 7) | def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=...
method forward (line 20) | def forward(self, x, attention_mask=None):
function create_model (line 39) | def create_model(batch_size=16, sequence_length=256, embedding_dim=2048,...
function test_export_model (line 45) | def test_export_model():
function test_compile_and_package_model (line 61) | def test_compile_and_package_model():
FILE: hopper/test_util.py
function generate_random_padding_mask (line 9) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r...
function generate_qkv (line 32) | def generate_qkv(
function construct_local_mask (line 157) | def construct_local_mask(
function construct_chunk_mask (line 193) | def construct_chunk_mask(
function attention_ref (line 226) | def attention_ref(
FILE: hopper/tile_scheduler.hpp
type TileSchedulerArguments (line 18) | struct TileSchedulerArguments {
class SingleTileScheduler (line 37) | class SingleTileScheduler {
type Params (line 44) | struct Params {
method Params (line 54) | static Params
method dim3 (line 65) | static dim3
type WorkTileInfo (line 70) | struct WorkTileInfo {
method CUTLASS_DEVICE (line 76) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 90) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 94) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 121) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 125) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 130) | CUTLASS_DEVICE
class StaticPersistentTileScheduler (line 141) | class StaticPersistentTileScheduler {
type Params (line 148) | struct Params {
method Params (line 154) | static Params
method dim3 (line 161) | static dim3
type WorkTileInfo (line 166) | struct WorkTileInfo {
method CUTLASS_DEVICE (line 169) | CUTLASS_DEVICE
method if (line 181) | if constexpr (Split) {
method CUTLASS_DEVICE (line 193) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 199) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 203) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 208) | CUTLASS_DEVICE
class DynamicPersistentTileScheduler (line 220) | class DynamicPersistentTileScheduler {
type Params (line 243) | struct Params {
method Params (line 252) | static Params
method dim3 (line 277) | static dim3
type WorkTileInfo (line 282) | struct WorkTileInfo {
method CUTLASS_DEVICE (line 285) | CUTLASS_DEVICE
method if (line 299) | if (bidhb < params.num_hb_quotient) {
method if (line 306) | if constexpr (Split) {
method CUTLASS_DEVICE (line 320) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 326) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 334) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 343) | CUTLASS_DEVICE
class SingleTileBwdLPTScheduler (line 368) | class SingleTileBwdLPTScheduler {
type Params (line 375) | struct Params {
method Params (line 386) | static Params
method dim3 (line 412) | static dim3
type WorkTileInfo (line 417) | struct WorkTileInfo {
method CUTLASS_DEVICE (line 422) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 436) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 440) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 472) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 476) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 481) | CUTLASS_DEVICE
class VarlenDynamicPersistentTileScheduler (line 493) | class VarlenDynamicPersistentTileScheduler {
type Params (line 507) | struct Params {
method Params (line 524) | static Params
method dim3 (line 547) | static dim3
type WorkTileInfo (line 552) | struct WorkTileInfo {
method CUTLASS_DEVICE (line 555) | CUTLASS_DEVICE
method if (line 572) | if constexpr (!Split) {
method else (line 574) | else {
method tile_idx_to_work_tile (line 597) | tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkT...
method CUTLASS_DEVICE (line 761) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 776) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 782) | CUTLASS_DEVICE
method CUTLASS_DEVICE (line 791) | CUTLASS_DEVICE
FILE: hopper/tile_size.h
function else (line 31) | else if (headdim <= 128) {
function else (line 36) | else if (headdim <= 192) {
FILE: hopper/utils.h
function namespace (line 26) | namespace flash {
function CUTLASS_DEVICE (line 675) | CUTLASS_DEVICE
FILE: setup.py
function cuda_archs (line 72) | def cuda_archs() -> str:
function get_platform (line 76) | def get_platform():
function get_cuda_bare_metal_version (line 91) | def get_cuda_bare_metal_version(cuda_dir):
function add_cuda_gencodes (line 100) | def add_cuda_gencodes(cc_flag, archs, bare_metal_version):
function get_hip_version (line 153) | def get_hip_version():
function check_if_cuda_home_none (line 157) | def check_if_cuda_home_none(global_option: str) -> None:
function check_if_rocm_home_none (line 169) | def check_if_rocm_home_none(global_option: str) -> None:
function detect_hipify_v2 (line 179) | def detect_hipify_v2():
function append_nvcc_threads (line 191) | def append_nvcc_threads(nvcc_extra_args):
function rename_cpp_to_cu (line 195) | def rename_cpp_to_cu(cpp_files):
function validate_and_update_archs (line 200) | def validate_and_update_archs(archs):
function get_package_version (line 508) | def get_package_version():
function get_wheel_url (line 519) | def get_wheel_url():
class CachedWheelsCommand (line 550) | class CachedWheelsCommand(_bdist_wheel):
method run (line 558) | def run(self):
class NinjaBuildExtension (line 585) | class NinjaBuildExtension(BuildExtension):
method __init__ (line 586) | def __init__(self, *args, **kwargs) -> None:
FILE: tests/cute/benchmark_block_sparsity.py
class BenchmarkConfig (line 32) | class BenchmarkConfig:
class BenchmarkResult (line 47) | class BenchmarkResult:
function benchmark_pytorch_block_sparsity (line 56) | def benchmark_pytorch_block_sparsity(
function benchmark_cute_block_sparsity (line 91) | def benchmark_cute_block_sparsity(
function run_benchmark (line 195) | def run_benchmark(
function generate_configs (line 220) | def generate_configs(
function print_results (line 243) | def print_results(results: List[BenchmarkResult]):
function main (line 296) | def main():
FILE: tests/cute/benchmark_mask_mod.py
class BenchmarkConfig (line 30) | class BenchmarkConfig:
class FlashAttentionBenchmark (line 83) | class FlashAttentionBenchmark:
method __init__ (line 84) | def __init__(self, config: BenchmarkConfig):
method _validate_config (line 112) | def _validate_config(self):
method _generate_varlen_seqlens (line 139) | def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tupl...
method _create_tensors (line 154) | def _create_tensors(self) -> Dict[str, torch.Tensor]:
method _compile_kernel (line 307) | def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[A...
method _calculate_flops (line 448) | def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float:
method benchmark (line 525) | def benchmark(self) -> Dict[str, Any]:
method _print_results (line 604) | def _print_results(self, results: Dict[str, Any]):
FILE: tests/cute/conftest.py
function _get_gpu_ids (line 11) | def _get_gpu_ids():
function pytest_configure (line 32) | def pytest_configure(config):
function pytest_collection_finish (line 61) | def pytest_collection_finish(session):
FILE: tests/cute/mask_mod_definitions.py
function cute_causal_mask (line 26) | def cute_causal_mask(
function get_cute_causal_mask (line 39) | def get_cute_causal_mask(offset: int):
function get_cute_block_causal_mask (line 43) | def get_cute_block_causal_mask(offset: int):
function get_cute_sliding_window_mask (line 60) | def get_cute_sliding_window_mask(window_left: int, window_right: int, of...
function cute_block_diagonal_mask (line 85) | def cute_block_diagonal_mask(
function cute_mini_causal_mask (line 98) | def cute_mini_causal_mask(
function cute_prefix_lm_mask (line 114) | def cute_prefix_lm_mask(
function cute_dilated_sliding_window_mask (line 130) | def cute_dilated_sliding_window_mask(
function cute_document_mask (line 148) | def cute_document_mask(
function cute_ima_mask (line 164) | def cute_ima_mask(
function get_flex_causal_mask (line 191) | def get_flex_causal_mask(offset: int):
function get_flex_block_causal_mask (line 198) | def get_flex_block_causal_mask(offset: int):
function get_flex_sliding_window_mask (line 205) | def get_flex_sliding_window_mask(window_left: int, window_right: int, of...
function flex_block_diagonal_mask (line 215) | def flex_block_diagonal_mask(b, h, q_idx, kv_idx):
function flex_mini_causal_mask (line 220) | def flex_mini_causal_mask(b, h, q_idx, kv_idx):
function flex_prefix_lm_mask (line 224) | def flex_prefix_lm_mask(b, h, q_idx, kv_idx):
function flex_dilated_sliding_window_mask (line 232) | def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx):
function flex_document_mask (line 241) | def flex_document_mask(b, h, q_idx, kv_idx, doc_id):
function flex_ima_mask (line 245) | def flex_ima_mask(b, h, q_idx, kv_idx, bias):
function random_doc_id_tensor (line 254) | def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"):
function get_mask_pair (line 298) | def get_mask_pair(mask_name, seqlen_q=None, seqlen_k=None, window_size=N...
FILE: tests/cute/score_mod_definitions.py
function score_mod_identity (line 14) | def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_inf...
function score_mod_identity_vectorized (line 19) | def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx,...
function score_mod_causal (line 24) | def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info,...
function score_mod_causal_vectorized (line 30) | def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, s...
function score_mod_rel_bias (line 41) | def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_inf...
function score_mod_rel_bias_vectorized (line 48) | def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx,...
function score_mod_rel_bias_x2 (line 60) | def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_...
function score_mod_rel_bias_x2_vectorized (line 68) | def score_mod_rel_bias_x2_vectorized(
function score_mod_times_two (line 82) | def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_in...
function score_mod_alibi (line 88) | def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, ...
function score_mod_alibi_vectorized (line 100) | def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, se...
function score_mod_sliding_window (line 116) | def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seql...
function score_mod_block_diagonal (line 124) | def score_mod_block_diagonal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seql...
function score_mod_causal_v2 (line 132) | def score_mod_causal_v2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_in...
function score_mod_batch_bias (line 139) | def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_i...
function score_mod_batch_bias_vectorized (line 150) | def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_id...
function score_mod_dual_buffer (line 161) | def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_...
function score_mod_dual_buffer_vectorized (line 181) | def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_i...
function score_mod_global_kv_bias (line 205) | def score_mod_global_kv_bias(
function score_mod_global_q_bias (line 222) | def score_mod_global_q_bias(
function score_mod_global_rel_plus_kv_bias (line 238) | def score_mod_global_rel_plus_kv_bias(
function score_mod_global_q_and_kv_bias (line 260) | def score_mod_global_q_and_kv_bias(
function score_mod_global_logical_rel_plus_kv_bias (line 290) | def score_mod_global_logical_rel_plus_kv_bias(
function score_mod_stress_complex_arithmetic (line 314) | def score_mod_stress_complex_arithmetic(
function score_mod_stress_conditional_mask (line 342) | def score_mod_stress_conditional_mask(
function score_mod_stress_multi_buffer (line 372) | def score_mod_stress_multi_buffer(
function score_mod_stress_global_offset (line 431) | def score_mod_stress_global_offset(
function score_mod_stress_xor_pattern (line 449) | def score_mod_stress_xor_pattern(
function score_mod_debug_global_idx (line 475) | def score_mod_debug_global_idx(
function identity_eager (line 490) | def identity_eager(score, b, h, q_idx, kv_idx):
function causal_eager (line 494) | def causal_eager(score, b, h, q_idx, kv_idx):
function rel_bias_eager (line 498) | def rel_bias_eager(score, b, h, q_idx, kv_idx):
function rel_bias_x2_eager (line 502) | def rel_bias_x2_eager(score, b, h, q_idx, kv_idx):
function times_two_eager (line 506) | def times_two_eager(score, b, h, q_idx, kv_idx):
function alibi_eager (line 510) | def alibi_eager(score, b, h, q_idx, kv_idx):
function sliding_window_eager (line 515) | def sliding_window_eager(score, b, h, q_idx, kv_idx):
function block_diagonal_eager (line 519) | def block_diagonal_eager(score, b, h, q_idx, kv_idx):
function causal_v2_eager (line 523) | def causal_v2_eager(score, b, h, q_idx, kv_idx):
function batch_bias_factory (line 527) | def batch_bias_factory(bias_tensor):
function dual_buffer_factory (line 534) | def dual_buffer_factory(head_bias, pos_bias):
function packed_kv_bias_factory (line 541) | def packed_kv_bias_factory(bias_tensor, cu_seqlens_k):
function packed_q_bias_factory (line 554) | def packed_q_bias_factory(bias_tensor, cu_seqlens_q):
function packed_rel_plus_kv_bias_factory (line 566) | def packed_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):
function packed_q_and_kv_bias_factory (line 580) | def packed_q_and_kv_bias_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqle...
function packed_logical_rel_plus_kv_bias_factory (line 597) | def packed_logical_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):
function stress_complex_arithmetic_factory (line 605) | def stress_complex_arithmetic_factory(bias, cu_seqlens_q):
function stress_conditional_mask_factory (line 618) | def stress_conditional_mask_factory(token_bias, cu_seqlens_q, cu_seqlens...
function stress_multi_buffer_factory (line 632) | def stress_multi_buffer_factory(
function stress_global_offset_factory (line 654) | def stress_global_offset_factory(token_bias, cu_seqlens_k):
function stress_xor_pattern_factory (line 661) | def stress_xor_pattern_factory(token_bias, cu_seqlens_q, cu_seqlens_k):
function debug_global_idx_factory (line 670) | def debug_global_idx_factory(bias, cu_seqlens_k):
FILE: tests/cute/test_block_sparsity.py
function _call_compute_block_sparsity (line 11) | def _call_compute_block_sparsity(
function _compare_block_sparsity (line 43) | def _compare_block_sparsity(
function test_fixed_length_masks (line 213) | def test_fixed_length_masks(
function test_parameterized_masks (line 292) | def test_parameterized_masks(
function test_edge_cases (line 364) | def test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n):
function test_fast_sampling (line 426) | def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_...
FILE: tests/cute/test_flash_attn.py
function test_flash_attn_output (line 100) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 453) | def test_flash_attn_varlen_output(
function test_flash_attn_kvcache (line 921) | def test_flash_attn_kvcache(
function test_flash_attn_bwd_preallocated_outputs (line 1430) | def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, caus...
function test_flash_attn_lse_grad (line 1469) | def test_flash_attn_lse_grad(seqlen_q, seqlen_k, d, causal, dtype):
function test_flash_attn_lse_grad_unused (line 1549) | def test_flash_attn_lse_grad_unused(seqlen_q, seqlen_k, d, causal, dtype):
function _generate_block_kvcache (line 1599) | def _generate_block_kvcache(
function test_flash_attn_paged_deepseek (line 1634) | def test_flash_attn_paged_deepseek(seqlen_q, page_size):
function test_flash_attn_invalid_head_dim (line 1684) | def test_flash_attn_invalid_head_dim(head_dim):
FILE: tests/cute/test_flash_attn_combine.py
function attention_combine_ref (line 19) | def attention_combine_ref(out_partial, lse_partial):
function check_combine_results (line 33) | def check_combine_results(out, lse, out_ref, lse_ref, dtype):
function test_flash_attn_combine (line 58) | def test_flash_attn_combine(num_splits, seqlen, d, dtype):
function test_flash_attn_combine_varlen (line 115) | def test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, d...
function test_flash_attn_combine_varlen_batch_idx (line 231) | def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype):
FILE: tests/cute/test_flash_attn_fast.py
function test_flash_attn_output (line 49) | def test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mh...
function test_flash_attn_varlen_output (line 116) | def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype):
function test_flash_attn_varlen_unpad_output (line 189) | def test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unp...
function attention_combine_ref (line 287) | def attention_combine_ref(out_partial, lse_partial):
function test_flash_attn_combine (line 300) | def test_flash_attn_combine(num_splits, seqlen, d, dtype):
FILE: tests/cute/test_flash_attn_race_condition.py
function test_flash_attn_output (line 63) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 393) | def test_flash_attn_varlen_output(
FILE: tests/cute/test_flash_attn_varlen.py
function test_varlen (line 17) | def test_varlen(
function check_varlen_vs_torch_flash (line 51) | def check_varlen_vs_torch_flash(
function generate_varlen_args (line 147) | def generate_varlen_args(
function torch_flash_ref (line 192) | def torch_flash_ref(
function _stats (line 296) | def _stats(name, a, b, atol, rtol):
FILE: tests/cute/test_mask_mod.py
function reset_torch_state (line 38) | def reset_torch_state():
function create_tensors (line 48) | def create_tensors(
function compute_reference_flex_attn (line 73) | def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tupl...
function get_coarse_block_mask_pair (line 111) | def get_coarse_block_mask_pair(sparse_tile_m: int, tile_n: int, last_blo...
function _run_mask_test (line 171) | def _run_mask_test(
function test_mask_mod_ima_partial_block (line 480) | def test_mask_mod_ima_partial_block():
function test_q_boundary_masking_block_sparse_bwd (line 514) | def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_na...
function test_single_doc_bwd_minimal (line 549) | def test_single_doc_bwd_minimal():
function test_static_masks (line 677) | def test_static_masks(
function test_parameterized_masks (line 725) | def test_parameterized_masks(
function test_sm100_block_sparse_sink_all_masked (line 758) | def test_sm100_block_sparse_sink_all_masked():
function test_sm100_block_sparse_q_stage1 (line 804) | def test_sm100_block_sparse_q_stage1():
function test_sm100_block_sparse_coarse_blocks (line 846) | def test_sm100_block_sparse_coarse_blocks():
function test_sm100_block_sparse_coarse_blocks_mismatch (line 943) | def test_sm100_block_sparse_coarse_blocks_mismatch():
function run_cute_mask_bwd (line 1053) | def run_cute_mask_bwd(
function run_flex_reference_bwd (line 1090) | def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None):
function test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message (line 1127) | def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_mess...
function test_gqa_block_sparse_broadcast_pattern_recompilation (line 1199) | def test_gqa_block_sparse_broadcast_pattern_recompilation():
function test_gqa_expand_stride_zero_bug (line 1301) | def test_gqa_expand_stride_zero_bug():
function test_persistent_blocksparse_empty_tiles (line 1416) | def test_persistent_blocksparse_empty_tiles():
FILE: tests/cute/test_score_mod.py
function create_tensors (line 113) | def create_tensors(
function run_cute_flash (line 122) | def run_cute_flash(q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=F...
function run_flex_reference (line 139) | def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Te...
function test_cute_vs_flex_attention (line 149) | def test_cute_vs_flex_attention(
function test_cute_score_mod_vectorized (line 202) | def test_cute_score_mod_vectorized(
function test_cute_vs_flex_attention_with_aux_tensors (line 235) | def test_cute_vs_flex_attention_with_aux_tensors(
function test_cute_score_mod_with_aux_tensors_vectorized (line 306) | def test_cute_score_mod_with_aux_tensors_vectorized(
function _generate_block_kvcache (line 354) | def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d...
function test_score_mod_with_paged_kvcache (line 396) | def test_score_mod_with_paged_kvcache(
function test_score_mod_with_paged_kvcache_aux_tensors (line 545) | def test_score_mod_with_paged_kvcache_aux_tensors(
function score_mod_bwd_5 (line 694) | def score_mod_bwd_5(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_inf...
function score_mod_bwd_3 (line 700) | def score_mod_bwd_3(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_inf...
function score_mod_bwd_identity (line 706) | def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seq...
function score_mod_bwd_causal (line 711) | def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqle...
function score_mod_squared (line 721) | def score_mod_squared(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info...
function score_mod_bwd_squared (line 727) | def score_mod_bwd_squared(grad, score, b_idx, h_idx, q_idx, kv_idx, seql...
function score_squared_eager (line 732) | def score_squared_eager(score, b, h, q_idx, kv_idx):
function run_cute_flash_bwd (line 754) | def run_cute_flash_bwd(
function run_flex_reference_bwd (line 796) | def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None):
function test_cute_vs_flex_attention_backward (line 832) | def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype...
function make_aux_tensors_for_bwd (line 881) | def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, nu...
function test_cute_vs_flex_attention_backward_with_aux (line 901) | def test_cute_vs_flex_attention_backward_with_aux(
function test_cute_vs_flex_attention_backward_pack_gqa (line 962) | def test_cute_vs_flex_attention_backward_pack_gqa(
FILE: tests/cute/test_score_mod_varlen.py
function run_cute_flash (line 183) | def run_cute_flash(
function run_flex_varlen_ref (line 232) | def run_flex_varlen_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, score_mod, ...
function setup_tensors (line 283) | def setup_tensors(seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, h...
function prepare_ref_tensors (line 324) | def prepare_ref_tensors(
function check_results (line 346) | def check_results(
function test_varlen_with_score_mod (line 412) | def test_varlen_with_score_mod(
function test_varlen_with_score_mod_vectorized (line 521) | def test_varlen_with_score_mod_vectorized(
function test_varlen_with_global_idx_score_mod (line 602) | def test_varlen_with_global_idx_score_mod(
function test_varlen_score_mod_kvcache (line 791) | def test_varlen_score_mod_kvcache(
function test_varlen_score_mod_with_paged_kvcache_global (line 950) | def test_varlen_score_mod_with_paged_kvcache_global(
FILE: tests/cute/test_utils.py
class TestHashCallable (line 9) | class TestHashCallable:
method test_returns_cute_hash_when_set_on_function (line 12) | def test_returns_cute_hash_when_set_on_function(self):
method test_returns_cute_hash_from_wrapped_function (line 23) | def test_returns_cute_hash_from_wrapped_function(self):
method test_prefers_wrapper_cute_hash_over_wrapped (line 39) | def test_prefers_wrapper_cute_hash_over_wrapped(self):
method test_fallback_to_source_hashing (line 56) | def test_fallback_to_source_hashing(self):
method test_same_function_produces_same_hash (line 67) | def test_same_function_produces_same_hash(self):
method test_different_functions_produce_different_hashes (line 77) | def test_different_functions_produce_different_hashes(self):
method test_fast_path_skips_expensive_hashing (line 90) | def test_fast_path_skips_expensive_hashing(self):
method test_fast_path_on_wrapped_skips_expensive_hashing (line 125) | def test_fast_path_on_wrapped_skips_expensive_hashing(self):
method test_closure_values_affect_hash (line 163) | def test_closure_values_affect_hash(self):
class TestHashCallableIntegration (line 182) | class TestHashCallableIntegration:
method test_repeated_calls_use_cached_hash (line 185) | def test_repeated_calls_use_cached_hash(self):
FILE: tests/layers/test_rotary.py
function test_rotary (line 21) | def test_rotary(rotary_emb_fraction, seqlen_offset):
function test_rotary_interleaved (line 95) | def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
FILE: tests/losses/test_cross_entropy.py
function test_cross_entropy_loss (line 28) | def test_cross_entropy_loss(
FILE: tests/losses/test_cross_entropy_parallel.py
function test_cross_entropy_loss_parallel (line 32) | def test_cross_entropy_loss_parallel(
FILE: tests/models/test_baichuan.py
function test_baichuan_state_dict (line 36) | def test_baichuan_state_dict(model_name):
function test_baichuan_optimized (line 60) | def test_baichuan_optimized(model_name):
function test_baichuan_parallel_forward (line 144) | def test_baichuan_parallel_forward(model_name, world_size):
function test_baichuan_generation (line 233) | def test_baichuan_generation(model_name):
function test_baichuan_parallel_generation (line 345) | def test_baichuan_parallel_generation(model_name, world_size):
FILE: tests/models/test_bert.py
function test_bert_state_dict (line 23) | def test_bert_state_dict(model_name):
function get_hf_models (line 33) | def get_hf_models(model_name, config, dtype):
function test_bert_non_optimized (line 53) | def test_bert_non_optimized(model_name):
function test_bert_optimized (line 100) | def test_bert_optimized(model_name):
function test_bert_dense_seq_output (line 207) | def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_la...
function test_inv_remap_state_dict (line 309) | def test_inv_remap_state_dict(model_name: str):
FILE: tests/models/test_bigcode.py
function test_bigcode_state_dict (line 15) | def test_bigcode_state_dict(model_name):
function test_bigcode_optimized (line 28) | def test_bigcode_optimized(model_name):
function test_bigcode_generation (line 88) | def test_bigcode_generation(model_name):
function test_inv_remap_state_dict (line 189) | def test_inv_remap_state_dict(model_name: str):
FILE: tests/models/test_btlm.py
function test_btlm_state_dict (line 16) | def test_btlm_state_dict(model_name):
function test_btlm_optimized (line 30) | def test_btlm_optimized(model_name):
function test_btlm_generation (line 100) | def test_btlm_generation(model_name):
function test_btlm_init (line 206) | def test_btlm_init(model_name):
FILE: tests/models/test_falcon.py
function test_falcon_state_dict (line 21) | def test_falcon_state_dict(model_name):
function test_falcon_optimized (line 36) | def test_falcon_optimized(model_name):
function test_falcon_parallel_forward (line 104) | def test_falcon_parallel_forward(model_name, world_size):
function test_falcon_generation (line 186) | def test_falcon_generation(model_name):
function test_falcon_parallel_generation (line 294) | def test_falcon_parallel_generation(model_name, world_size):
FILE: tests/models/test_gpt.py
function test_gpt2_state_dict (line 20) | def test_gpt2_state_dict(model_name):
function test_gpt2_non_optimized (line 32) | def test_gpt2_non_optimized(model_name):
function test_gpt2_optimized (line 82) | def test_gpt2_optimized(model_name):
function test_gpt2_generation (line 142) | def test_gpt2_generation(model_name, rotary, optimized):
function get_logits (line 264) | def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwa...
function test_gpt2_generation_cg (line 282) | def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
function test_gpt2_multiple_token_generation (line 345) | def test_gpt2_multiple_token_generation(model_name, optimized):
function test_gpt2_speculative_decoding (line 391) | def test_gpt2_speculative_decoding(model_name, optimized, cg):
function test_gpt2_shard_unshard (line 460) | def test_gpt2_shard_unshard(n_heads_q_kv):
FILE: tests/models/test_gpt_generation_parallel.py
function test_tensor_parallel (line 21) | def test_tensor_parallel(model_name, rotary, world_size):
FILE: tests/models/test_gpt_neox.py
function test_gptj_state_dict (line 15) | def test_gptj_state_dict(model_name):
function test_gpt_neox_optimized (line 36) | def test_gpt_neox_optimized(model_name):
FILE: tests/models/test_gpt_parallel.py
function test_gpt_parallel (line 29) | def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, d...
FILE: tests/models/test_gptj.py
function test_gptj_state_dict (line 16) | def test_gptj_state_dict(model_name):
function test_gptj_optimized (line 27) | def test_gptj_optimized(model_name):
function test_gptj_generation (line 87) | def test_gptj_generation(model_name):
FILE: tests/models/test_llama.py
function _pretrained_state_dict_from_checkpoint (line 36) | def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, ...
function test_llama_state_dict (line 50) | def test_llama_state_dict(model_name):
function test_inv_remap_state_dict_hf_llama (line 68) | def test_inv_remap_state_dict_hf_llama(model_name):
function test_llama_optimized (line 95) | def test_llama_optimized(model_name):
function test_llama_parallel (line 186) | def test_llama_parallel(model_name, world_size):
function test_llama_generation (line 289) | def test_llama_generation(model_name, checkpoint_format):
function test_llama_parallel_generation (line 402) | def test_llama_parallel_generation(model_name, world_size):
function test_llama_parallel_uneven_num_heads (line 537) | def test_llama_parallel_uneven_num_heads(world_size):
FILE: tests/models/test_opt.py
function test_opt_state_dict (line 19) | def test_opt_state_dict(model_name):
function test_opt_optimized (line 33) | def test_opt_optimized(model_name):
function test_opt_generation (line 100) | def test_opt_generation(model_name):
FILE: tests/models/test_vit.py
function test_vit (line 13) | def test_vit(optimized, fused_mlp):
FILE: tests/modules/test_block_parallel.py
function test_block_parallel (line 28) | def test_block_parallel(dim, sequence_parallel, world_size, dtype):
FILE: tests/modules/test_embedding_parallel.py
function test_embedding_parallel (line 24) | def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_s...
FILE: tests/modules/test_mha_parallel.py
function test_mha_parallel (line 26) | def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size...
FILE: tests/modules/test_mlp_parallel.py
function test_mlp_parallel (line 24) | def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dt...
FILE: tests/ops/test_dropout_layer_norm.py
function test_dropout_layer_norm_training (line 52) | def test_dropout_layer_norm_training(
function test_dropout_layer_norm_eval (line 177) | def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtyp...
function test_dropout_layer_norm_prenorm_training (line 239) | def test_dropout_layer_norm_prenorm_training(
function test_dropout_layer_norm_prenorm_eval (line 371) | def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, resid...
function test_dropout_layer_norm_subset_training (line 435) | def test_dropout_layer_norm_subset_training(
function test_dropout_layer_norm_subset_prenorm_training (line 592) | def test_dropout_layer_norm_subset_prenorm_training(
function test_dropout_layer_norm_parallel_residual_training (line 762) | def test_dropout_layer_norm_parallel_residual_training(
function test_dropout_layer_norm_parallel_residual_prenorm_training (line 971) | def test_dropout_layer_norm_parallel_residual_prenorm_training(
function test_dropout_layer_norm_randomness (line 1161) | def test_dropout_layer_norm_randomness():
FILE: tests/ops/test_fused_dense.py
function test_fused_linear_bias (line 16) | def test_fused_linear_bias(in_features, out_features, has_bias, return_r...
function test_fused_mlp (line 92) | def test_fused_mlp(
FILE: tests/ops/test_fused_dense_parallel.py
function test_fused_linear_bias (line 25) | def test_fused_linear_bias(
function test_fused_mlp (line 124) | def test_fused_mlp(in_features, out_features, has_bias2, sequence_parall...
FILE: tests/ops/triton/test_layer_norm.py
function test_layer_norm (line 47) | def test_layer_norm(
function test_layer_norm_linear (line 263) | def test_layer_norm_linear(
FILE: tests/test_flash_attn.py
function attn_bias_from_alibi_slopes (line 29) | def attn_bias_from_alibi_slopes(
function generate_random_padding_mask (line 58) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r...
function generate_qkv (line 74) | def generate_qkv(
function construct_local_mask (line 182) | def construct_local_mask(
function attention_ref (line 217) | def attention_ref(
function attention_kvpacked_ref (line 307) | def attention_kvpacked_ref(
function attention_qkvpacked_ref (line 340) | def attention_qkvpacked_ref(
function generate_sparsity_mask (line 369) | def generate_sparsity_mask(seqlen, sparsity=0.3):
function attention_blocksparse_ref (line 382) | def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, drop...
function convert_flash_attn_S_to_softmax (line 411) | def convert_flash_attn_S_to_softmax(
function normalize_flash_attn_S (line 465) | def normalize_flash_attn_S(
function get_dropout_fraction (line 529) | def get_dropout_fraction(
function test_flash_attn_qkvpacked (line 586) | def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi...
function test_flash_attn_varlen_qkvpacked (line 733) | def test_flash_attn_varlen_qkvpacked(
function test_flash_attn_output (line 903) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 1172) | def test_flash_attn_varlen_output(
function test_flash_attn_causal (line 1482) | def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dty...
function test_flash_attn_varlen_causal (line 1593) | def test_flash_attn_varlen_causal(
function test_flash_attn_splitkv (line 1765) | def test_flash_attn_splitkv(
function test_flash_attn_kvcache (line 1907) | def test_flash_attn_kvcache(
function _generate_block_kvcache (line 2143) | def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, n...
function test_flash_attn_race_condition (line 2199) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, cau...
function test_flash_attn_bwd_overflow (line 2247) | def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
function test_flash_attn_bwd_transpose (line 2303) | def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
function test_flash_attn_bwd_varlen_overflow (line 2355) | def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
function test_flash_attn_deterministic (line 2413) | def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, cau...
function test_flash_attn_varlen_deterministic (line 2471) | def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk,...
FILE: tests/test_flash_attn_ck.py
function is_bwd_hdim_supported (line 30) | def is_bwd_hdim_supported(d):
function ck_randval_to_dropout_mask (line 34) | def ck_randval_to_dropout_mask(randval, p):
function pad_rearrange_dropout_mask_hts_to_bhss (line 41) | def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen...
function test_flash_attn_qkvpacked (line 73) | def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi...
function test_flash_attn_varlen_qkvpacked (line 171) | def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local...
function test_flash_attn_output (line 305) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 522) | def test_flash_attn_varlen_output(
function test_flash_attn_causal (line 780) | def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dty...
function test_flash_attn_varlen_causal (line 880) | def test_flash_attn_varlen_causal(
function test_flash_attn_kvcache (line 1053) | def test_flash_attn_kvcache(
function test_flash_attn_race_condition (line 1314) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, cau...
function test_flash_attn_bwd_overflow (line 1360) | def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
function test_flash_attn_bwd_transpose (line 1417) | def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
function test_flash_attn_bwd_varlen_overflow (line 1467) | def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
function test_flash_attn_deterministic (line 1515) | def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, cau...
function test_flash_attn_varlen_deterministic (line 1563) | def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk,...
FILE: tests/test_flash_attn_triton_amd.py
function _get_block_size_n_triton (line 22) | def _get_block_size_n_triton(device, head_dim, is_dropout, is_causal):
function attn_bias_from_alibi_slopes (line 44) | def attn_bias_from_alibi_slopes(
function generate_random_padding_mask (line 73) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r...
function generate_qkv (line 89) | def generate_qkv(
function construct_local_mask (line 197) | def construct_local_mask(
function attention_ref (line 232) | def attention_ref(
function attention_kvpacked_ref (line 322) | def attention_kvpacked_ref(
function attention_qkvpacked_ref (line 355) | def attention_qkvpacked_ref(
function generate_sparsity_mask (line 384) | def generate_sparsity_mask(seqlen, sparsity=0.3):
function attention_blocksparse_ref (line 397) | def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, drop...
function convert_flash_attn_S_to_softmax (line 426) | def convert_flash_attn_S_to_softmax(
function normalize_flash_attn_S (line 480) | def normalize_flash_attn_S(
function get_dropout_fraction (line 544) | def get_dropout_fraction(
function test_flash_attn_qkvpacked (line 601) | def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi...
function test_flash_attn_varlen_qkvpacked (line 748) | def test_flash_attn_varlen_qkvpacked(
function test_flash_attn_output (line 918) | def test_flash_attn_output(
function test_flash_attn_varlen_output (line 1191) | def test_flash_attn_varlen_output(
function test_flash_attn_causal (line 1504) | def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dty...
function test_flash_attn_varlen_causal (line 1619) | def test_flash_attn_varlen_causal(
function test_flash_attn_splitkv (line 1792) | def test_flash_attn_splitkv(
function test_flash_attn_kvcache (line 1937) | def test_flash_attn_kvcache(
function _generate_block_kvcache (line 2173) | def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, n...
function test_flash_attn_race_condition (line 2230) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, cau...
function test_flash_attn_bwd_overflow (line 2279) | def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
function test_flash_attn_bwd_transpose (line 2336) | def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
function test_flash_attn_bwd_varlen_overflow (line 2389) | def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
function test_flash_attn_deterministic (line 2448) | def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, cau...
function test_flash_attn_varlen_deterministic (line 2507) | def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk,...
FILE: tests/test_rotary.py
function generate_cos_sin (line 18) | def generate_cos_sin(seqlen, rotary_dim, device, dtype):
function generate_seqlen_offsets (line 26) | def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, dev...
function index_cos_sin (line 35) | def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
function test_rotary_emb_func (line 60) | def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_o...
function test_rotary_emb_qkv (line 113) | def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_typ...
function test_rotary_emb_kv (line 181) | def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type...
function test_rotary_emb_varlen_func (line 229) | def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, s...
function test_compilation_count (line 281) | def test_compilation_count():
FILE: tests/test_util.py
function generate_random_padding_mask (line 8) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r...
function generate_qkv (line 31) | def generate_qkv(
function construct_local_mask (line 150) | def construct_local_mask(
function attention_ref (line 185) | def attention_ref(
FILE: tools/sass_diff.py
class Line (line 35) | class Line:
function _normalize_instr (line 44) | def _normalize_instr(text: str) -> str:
function parse_sass (line 68) | def parse_sass(path: str) -> list[Line]:
class DiffBlock (line 95) | class DiffBlock:
function diff_sass (line 101) | def diff_sass(a_lines: list[Line], b_lines: list[Line]) -> list[DiffBlock]:
function _fmt (line 120) | def _fmt(line: Line, prefix: str, color: str, use_color: bool, show_norm...
function print_diff (line 128) | def print_diff(blocks: list[DiffBlock], context: int = 3,
function _get_opcode (line 174) | def _get_opcode(raw: str) -> str | None:
function print_summary (line 183) | def print_summary(a_all: list[Line], b_all: list[Line], blocks: list[Dif...
function main (line 221) | def main():
FILE: training/run.py
function dictconfig_filter_key (line 23) | def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig:
function main (line 34) | def main(config: DictConfig):
FILE: training/src/callbacks/causality_monitor.py
class CausalityMonitor (line 9) | class CausalityMonitor(Callback):
method __init__ (line 26) | def __init__(self, seq_len: int = 10, input_dim: int = 0):
method on_train_epoch_end (line 32) | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lig...
FILE: training/src/callbacks/ema.py
class EMACallback (line 16) | class EMACallback(Callback):
method __init__ (line 19) | def __init__(self, decay: float, use_num_updates: bool = True):
method on_train_start (line 30) | def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightni...
method on_train_batch_end (line 40) | def on_train_batch_end(
method on_validation_start (line 51) | def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.Li...
method on_validation_end (line 57) | def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.Ligh...
method on_test_start (line 61) | def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin...
method on_test_end (line 66) | def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningM...
method on_save_checkpoint (line 70) | def on_save_checkpoint(
method on_load_checkpoint (line 75) | def on_load_checkpoint(
FILE: training/src/callbacks/flop_count.py
class FlopCount (line 14) | class FlopCount(Callback):
method __init__ (line 17) | def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'],
method on_fit_start (line 34) | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -...
FILE: training/src/callbacks/gpu_affinity.py
function l2_promote (line 10) | def l2_promote():
function set_affinity (line 21) | def set_affinity(trainer):
class GpuAffinity (line 34) | class GpuAffinity(Callback):
method setup (line 39) | def setup(self, trainer: Trainer, pl_module: LightningModule, stage=No...
FILE: training/src/callbacks/loss_scale_monitor.py
class LossScaleMonitor (line 9) | class LossScaleMonitor(Callback):
method on_before_optimizer_step (line 17) | def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwa...
FILE: training/src/callbacks/model_checkpoint.py
class ModelCheckpointMine (line 8) | class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint):
method __init__ (line 10) | def __init__(self, *args, fault_tolerant=False, **kwargs):
method on_exception (line 14) | def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> N...
FILE: training/src/callbacks/norm_monitor.py
class NormMonitor (line 22) | class NormMonitor(Callback):
method __init__ (line 26) | def __init__(self, layer_norm_only: bool = False):
method on_before_optimizer_step (line 33) | def on_before_optimizer_step(self, trainer: Trainer, pl_module, *args:...
FILE: training/src/callbacks/params_log.py
class ParamsLog (line 8) | class ParamsLog(Callback):
method __init__ (line 11) | def __init__(self, total_params_log: bool = True, trainable_params_log...
method on_fit_start (line 23) | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -...
FILE: training/src/callbacks/speed_monitor.py
class SpeedMonitor (line 12) | class SpeedMonitor(Callback):
method __init__ (line 15) | def __init__(self, intra_step_time: bool = True, inter_step_time: bool...
method on_train_start (line 27) | def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightni...
method on_train_epoch_start (line 30) | def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.L...
method on_validation_epoch_start (line 35) | def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: ...
method on_test_epoch_start (line 38) | def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Li...
method on_train_batch_start (line 42) | def on_train_batch_start(
method on_train_batch_end (line 64) | def on_train_batch_end(
method on_train_epoch_end (line 89) | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lig...
FILE: training/src/callbacks/wandb_callbacks.py
function get_wandb_logger (line 16) | def get_wandb_logger(trainer: Trainer) -> WandbLogger:
class WatchModel (line 37) | class WatchModel(Callback):
method __init__ (line 40) | def __init__(self, log: str = "gradients", log_freq: int = 100):
method on_train_start (line 45) | def on_train_start(self, trainer, pl_module):
class UploadCodeAsArtifact (line 50) | class UploadCodeAsArtifact(Callback):
method __init__ (line 53) | def __init__(self, code_dir: str, use_git: bool = True):
method on_train_start (line 65) | def on_train_start(self, trainer, pl_module):
class UploadCheckpointsAsArtifact (line 97) | class UploadCheckpointsAsArtifact(Callback):
method __init__ (line 100) | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: b...
method on_keyboard_interrupt (line 105) | def on_keyboard_interrupt(self, trainer, pl_module):
method on_train_end (line 109) | def on_train_end(self, trainer, pl_module):
class LogConfusionMatrix (line 124) | class LogConfusionMatrix(Callback):
method __init__ (line 129) | def __init__(self):
method on_sanity_check_start (line 134) | def on_sanity_check_start(self, trainer, pl_module) -> None:
method on_sanity_check_end (line 137) | def on_sanity_check_end(self, trainer, pl_module):
method on_validation_batch_end (line 141) | def on_validation_batch_end(
method on_validation_epoch_end (line 149) | def on_validation_epoch_end(self, trainer, pl_module):
class LogF1PrecRecHeatmap (line 182) | class LogF1PrecRecHeatmap(Callback):
method __init__ (line 187) | def __init__(self, class_names: List[str] = None):
method on_sanity_check_start (line 192) | def on_sanity_check_start(self, trainer, pl_module):
method on_sanity_check_end (line 195) | def on_sanity_check_end(self, trainer, pl_module):
method on_validation_batch_end (line 199) | def on_validation_batch_end(
method on_validation_epoch_end (line 207) | def on_validation_epoch_end(self, trainer, pl_module):
class LogImagePredictions (line 245) | class LogImagePredictions(Callback):
method __init__ (line 251) | def __init__(self, num_samples: int = 8):
method on_sanity_check_start (line 256) | def on_sanity_check_start(self, trainer, pl_module):
method on_sanity_check_end (line 259) | def on_sanity_check_end(self, trainer, pl_module):
method on_validation_epoch_end (line 263) | def on_validation_epoch_end(self, trainer, pl_module):
FILE: training/src/datamodules/datasets/detokenizer.py
function wikitext_detokenize (line 10) | def wikitext_detokenize(string: str) -> str:
FILE: training/src/datamodules/datasets/lm_dataset.py
class LMDataset (line 10) | class LMDataset(torch.utils.data.Dataset):
method __init__ (line 12) | def __init__(self, tokens, seq_len, drop_last=True):
method __len__ (line 25) | def __len__(self):
method __getitem__ (line 28) | def __getitem__(self, idx):
FILE: training/src/datamodules/fault_tolerant_sampler.py
class RandomFaultTolerantSampler (line 9) | class RandomFaultTolerantSampler(RandomSampler):
method __init__ (line 11) | def __init__(self, *args, generator=None, **kwargs):
method state_dict (line 26) | def state_dict(self):
method load_state_dict (line 29) | def load_state_dict(self, state_dict):
method __iter__ (line 43) | def __iter__(self) -> Iterator[int]:
class FaultTolerantDistributedSampler (line 64) | class FaultTolerantDistributedSampler(DistributedSampler):
method __init__ (line 66) | def __init__(self, *args, **kwargs):
method state_dict (line 72) | def state_dict(self):
method load_state_dict (line 75) | def load_state_dict(self, state_dict):
method __iter__ (line 86) | def __iter__(self):
FILE: training/src/datamodules/imagenet.py
class DictDataset (line 17) | class DictDataset(Dataset):
method __init__ (line 19) | def __init__(self, dataset_dict, length=None):
method __getitem__ (line 28) | def __getitem__(self, index):
method __len__ (line 31) | def __len__(self):
function imagenet_normalization (line 36) | def imagenet_normalization():
class ImagenetDataModule (line 40) | class ImagenetDataModule(LightningDataModule):
method __init__ (line 63) | def __init__(
method num_classes (line 116) | def num_classes(self) -> int:
method _verify_splits (line 123) | def _verify_splits(self, data_dir: str, split: str) -> None:
method prepare_data (line 132) | def prepare_data(self) -> None:
method setup (line 139) | def setup(self, stage: Optional[str] = None) -> None:
method train_transform (line 164) | def train_transform(self) -> Callable:
method val_transform (line 188) | def val_transform(self) -> Callable:
method train_dataloader (line 212) | def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
method val_dataloader (line 224) | def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoade...
method test_dataloader (line 250) | def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoad...
method _data_loader (line 256) | def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: boo...
class Imagenet21kPDataModule (line 273) | class Imagenet21kPDataModule(ImagenetDataModule):
method num_classes (line 278) | def num_classes(self) -> int:
FILE: training/src/datamodules/language_modeling_hf.py
class SHMArray (line 29) | class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/us...
method __new__ (line 31) | def __new__(cls, input_array, shm=None):
method __array_finalize__ (line 36) | def __array_finalize__(self, obj):
class LMDataModule (line 41) | class LMDataModule(LightningDataModule):
method __init__ (line 42) | def __init__(self, dataset_name, tokenizer_name, dataset_config_name=N...
method prepare_data (line 80) | def prepare_data(self):
method setup (line 86) | def setup(self, stage=None):
method process_dataset (line 97) | def process_dataset(self):
method _save_to_cache (line 232) | def _save_to_cache(self, concat_ids, tokenizer, cache_dir):
method _load_from_cache (line 240) | def _load_from_cache(self, cache_dir):
method _cache_dir_name (line 250) | def _cache_dir_name(self):
method train_dataloader (line 253) | def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
method val_dataloader (line 272) | def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoade...
method test_dataloader (line 276) | def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoad...
method _data_loader (line 280) | def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: boo...
method load_state_dict (line 293) | def load_state_dict(self, checkpoint):
FILE: training/src/datamodules/timm_mixup.py
class TimmMixup (line 7) | class TimmMixup(Mixup):
method __call__ (line 10) | def __call__(self, x, target):
FILE: training/src/distributed/ddp_comm_hooks.py
function fp16_compress_hook (line 9) | def fp16_compress_hook(
FILE: training/src/eval.py
function remove_prefix (line 22) | def remove_prefix(text: str, prefix: str):
function load_checkpoint (line 28) | def load_checkpoint(path, device='cpu'):
function evaluate (line 47) | def evaluate(config: DictConfig) -> None:
FILE: training/src/metrics/accuracy.py
class AccuracyMine (line 7) | class AccuracyMine(Accuracy):
method update (line 10) | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
FILE: training/src/metrics/num_tokens.py
class NumTokens (line 9) | class NumTokens(Metric):
method __init__ (line 22) | def __init__(self, **kwargs: Dict[str, Any]):
method update (line 27) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]...
method compute (line 30) | def compute(self) -> Tensor:
method reset (line 33) | def reset(self):
method _forward_reduce_state_update (line 39) | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
FILE: training/src/metrics/perplexity.py
class Perplexity (line 21) | class Perplexity(Metric):
method __init__ (line 43) | def __init__(self, **kwargs: Dict[str, Any]):
method update (line 51) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]...
method compute (line 65) | def compute(self) -> Tensor:
FILE: training/src/models/modules/seq_common.py
function pooling (line 15) | def pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=Tr...
class ClassificationHeadLinear (line 49) | class ClassificationHeadLinear(nn.Module):
method __init__ (line 52) | def __init__(self, d_model, num_classes, pooling_mode='MEAN',
method forward (line 60) | def forward(self, hidden_states, key_padding_mask=None, **kwargs):
class ClassificationHead (line 71) | class ClassificationHead(nn.Module):
method __init__ (line 74) | def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling...
method forward (line 84) | def forward(self, hidden_states, key_padding_mask=None, **kwargs):
class ClassificationHeadDual (line 99) | class ClassificationHeadDual(nn.Module):
method __init__ (line 102) | def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling...
method forward (line 114) | def forward(self, hidden_states1, hidden_states2,
class LMHead (line 134) | class LMHead(nn.Module):
method __init__ (line 136) | def __init__(self, d_model, num_classes, batch_first=True, bias=True):
method forward (line 140) | def forward(self, hidden_states, **kwargs):
function sinusoidal_init_ (line 148) | def sinusoidal_init_(tensor):
class PositionalEncoding (line 161) | class PositionalEncoding(nn.Module):
method __init__ (line 178) | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=Fal...
method forward (line 192) | def forward(self, x):
class Mlp (line 207) | class Mlp(nn.Module):
method __init__ (line 210) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 227) | def forward(self, x):
class MlpBig (line 236) | class MlpBig(nn.Module):
method __init__ (line 239) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 259) | def forward(self, x):
class GluMlp (line 262) | class GluMlp(nn.Module):
method __init__ (line 266) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method init_weights (line 276) | def init_weights(self):
method forward (line 282) | def forward(self, x):
class GatedMlp (line 292) | class GatedMlp(nn.Module):
method __init__ (line 295) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 311) | def forward(self, x):
class ConvMlp (line 321) | class ConvMlp(nn.Module):
method __init__ (line 324) | def __init__(
method forward (line 335) | def forward(self, x):
FILE: training/src/optim/param_grouping.py
function group_parameters_for_optimizer (line 15) | def group_parameters_for_optimizer(model, optimizer_cfg, bias_weight_dec...
FILE: training/src/optim/timm_lr_scheduler.py
class TimmCosineLRScheduler (line 8) | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler....
method __init__ (line 13) | def __init__(self, *args, **kwargs):
method step (line 18) | def step(self, epoch=None):
FILE: training/src/tasks/seq.py
class SequenceModel (line 20) | class SequenceModel(LightningModule):
method __init__ (line 22) | def __init__(self, cfg, model_cfg=None):
method instantiate_datamodule (line 38) | def instantiate_datamodule(self):
method instantiate_model (line 47) | def instantiate_model(self):
method instantiate_loss (line 58) | def instantiate_loss(self):
method instantiate_metrics (line 66) | def instantiate_metrics(self):
method warmstart (line 79) | def warmstart(self):
method forward (line 90) | def forward(self, *args, **kwargs):
method step (line 93) | def step(self, batch: Any, is_train=True):
method shared_step (line 103) | def shared_step(self, batch: Any, batch_idx: int, phase='train'):
method training_step (line 117) | def training_step(self, batch: Any, batch_idx: int):
method validation_step (line 120) | def validation_step(self, batch: Any, batch_idx: int):
method test_step (line 123) | def test_step(self, batch: Any, batch_idx: int):
method configure_optimizers (line 126) | def configure_optimizers(self):
method optimizer_zero_grad (line 151) | def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_i...
method on_save_checkpoint (line 159) | def on_save_checkpoint(self, checkpoint):
class SequenceLMModel (line 169) | class SequenceLMModel(SequenceModel):
method step (line 171) | def step(self, batch: Any, is_train=True):
method shared_step (line 179) | def shared_step(self, batch: Any, batch_idx: int, phase='train'):
FILE: training/src/train.py
function last_modification_time (line 20) | def last_modification_time(path):
function train (line 32) | def train(config: DictConfig) -> Optional[float]:
FILE: training/src/utils/checkpoint.py
function load_checkpoint (line 8) | def load_checkpoint(path, device='cpu'):
function blockdiag_to_dense_mlp_bert (line 32) | def blockdiag_to_dense_mlp_bert(state_dict):
function interpolate_pos_embedding (line 41) | def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name...
function remove_model_prefix (line 68) | def remove_model_prefix(state_dict):
FILE: training/src/utils/ddp_zero1.py
function get_zero_optimizer_state_dict_local (line 24) | def get_zero_optimizer_state_dict_local(optimizer, global_rank):
class DDPStrategyZero1 (line 62) | class DDPStrategyZero1(DDPStrategy):
method optimizer_state (line 69) | def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
method save_checkpoint (line 77) | def save_checkpoint(
method load_checkpoint (line 96) | def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
FILE: training/src/utils/ddp_zero2.py
class DistAdamNativeMixedPrecisionPlugin (line 26) | class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
method optimizer_step (line 28) | def optimizer_step( # type: ignore[override]
method clip_grad_by_norm (line 64) | def clip_grad_by_norm(self, optimizer: DistributedFusedAdam, clip_val:...
class DDPStrategyZero2 (line 73) | class DDPStrategyZero2(DDPStrategy):
method __init__ (line 80) | def __init__(
method precision_plugin (line 92) | def precision_plugin(self) -> PrecisionPlugin:
method precision_plugin (line 96) | def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]...
method optimizer_state (line 106) | def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
method save_checkpoint (line 114) | def save_checkpoint(
method load_checkpoint (line 133) | def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
FILE: training/src/utils/distributed.py
function init_distributed (line 23) | def init_distributed(cuda):
function barrier (line 39) | def barrier():
function get_rank (line 47) | def get_rank():
function get_world_size (line 58) | def get_world_size():
function all_reduce_item (line 70) | def all_reduce_item(value, op='sum'):
function sync_workers (line 105) | def sync_workers():
FILE: training/src/utils/ema.py
function to_float_maybe (line 13) | def to_float_maybe(x):
class ExponentialMovingAverage (line 19) | class ExponentialMovingAverage:
method __init__ (line 29) | def __init__(
method _get_parameters (line 50) | def _get_parameters(
method update (line 76) | def update(
method copy_to (line 106) | def copy_to(
method store (line 123) | def store(
method restore (line 141) | def restore(
method average_parameters (line 168) | def average_parameters(
method to (line 195) | def to(self, device=None, dtype=None) -> None:
method state_dict (line 216) | def state_dict(self) -> dict:
method load_state_dict (line 228) | def load_state_dict(self, state_dict: dict) -> None:
FILE: training/src/utils/flops.py
function profile_deepspeed (line 20) | def profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch...
function profile_fvcore (line 35) | def profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.fl...
FILE: training/src/utils/gpu_affinity.py
function systemGetDriverVersion (line 12) | def systemGetDriverVersion():
function deviceGetCount (line 16) | def deviceGetCount():
class device (line 20) | class device:
method __init__ (line 24) | def __init__(self, device_idx):
method getName (line 28) | def getName(self):
method getCpuAffinity (line 31) | def getCpuAffinity(self):
function set_socket_affinity (line 45) | def set_socket_affinity(gpu_id):
function set_single_affinity (line 51) | def set_single_affinity(gpu_id):
function set_single_unique_affinity (line 57) | def set_single_unique_affinity(gpu_id, nproc_per_node):
function set_socket_unique_affinity (line 80) | def set_socket_unique_affinity(gpu_id, nproc_per_node, mode):
function get_thread_siblings_list (line 113) | def get_thread_siblings_list():
function set_affinity (line 127) | def set_affinity(gpu_id, nproc_per_node, mode='socket'):
FILE: training/src/utils/utils.py
class LoggingContext (line 13) | class LoggingContext:
method __init__ (line 14) | def __init__(self, logger, level=None, handler=None, close=True):
method __enter__ (line 20) | def __enter__(self):
method __exit__ (line 27) | def __exit__(self, et, ev, tb):
function get_logger (line 37) | def get_logger(name=__name__) -> logging.Logger:
function extras (line 50) | def extras(config: DictConfig) -> None:
function print_config (line 89) | def print_config(
function finish (line 131) | def finish(
FILE: training/tests/datamodules/test_language_modeling_hf.py
function div_up (line 19) | def div_up(x: int, y: int) -> int:
function num_cpu_cores (line 24) | def num_cpu_cores():
class TestLMDataModule (line 32) | class TestLMDataModule:
method test_wikitext2 (line 34) | def test_wikitext2(self):
method test_wikitext103 (line 64) | def test_wikitext103(self):
method test_openwebtext (line 94) | def test_openwebtext(self):
method test_lambada (line 125) | def test_lambada(self):
method test_the_pile (line 156) | def test_the_pile(self):
method test_pg19 (line 188) | def test_pg19(self):
Condensed preview — 994 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (5,533K chars).
[
{
"path": ".github/workflows/README.md",
"chars": 1211,
"preview": "# GitHub Workflow Tagging Flow\n\nThis repository uses separate tag lanes so FA2 and FA4 publishing do not collide.\n\n## Re"
},
{
"path": ".github/workflows/_build.yml",
"chars": 9943,
"preview": "name: ~Build wheel template\n\non:\n workflow_call:\n inputs:\n runs-on:\n description: \"The runner to use for"
},
{
"path": ".github/workflows/build.yml",
"chars": 1384,
"preview": "name: Build wheels\n\non:\n workflow_dispatch:\n inputs:\n runs-on:\n description: \"The runner to use for the "
},
{
"path": ".github/workflows/pre-commit.yaml",
"chars": 797,
"preview": "name: Lint\n\non:\n pull_request:\n paths:\n - 'flash_attn/cute/flash_bwd_sm90.py'\n - 'flash_attn/cute/flash_bw"
},
{
"path": ".github/workflows/publish-fa4.yml",
"chars": 1561,
"preview": "name: Publish flash-attn-4 to PyPI\n\non:\n push:\n tags:\n - 'fa4-v*'\n\npermissions:\n contents: write\n\njobs:\n buil"
},
{
"path": ".github/workflows/publish.yml",
"chars": 4008,
"preview": "# This workflow will:\n# - Create a new Github release\n# - Build wheels for supported architectures\n# - Deploy the wheels"
},
{
"path": ".gitignore",
"chars": 371,
"preview": "*.ncu-rep\n*.sass\n*.ptx\n*.cubin\n*.plk\n.DS_store\n.vscode\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n\n"
},
{
"path": ".gitmodules",
"chars": 334,
"preview": "[submodule \"csrc/cutlass\"]\n\tpath = csrc/cutlass\n\turl = https://github.com/NVIDIA/cutlass.git\n[submodule \"csrc/composable"
},
{
"path": ".pre-commit-config.yaml",
"chars": 476,
"preview": "repos:\n - repo: https://github.com/astral-sh/ruff-pre-commit\n rev: v0.11.13\n hooks:\n - id: ruff-check\n "
},
{
"path": "AI/DEBUG_2CTA.md",
"chars": 5923,
"preview": "# Debugging GPU Kernel Hangs (Deadlocks) in CUTLASS DSL / 2CTA Kernels\n\n## General Approach to Debugging Kernel Hangs\n\n#"
},
{
"path": "AI/RACECHECK_TMA_HAZARD.md",
"chars": 7826,
"preview": "# compute-sanitizer racecheck hazard with `cp.async.bulk`\n\n## Summary\n\n`compute-sanitizer --tool=racecheck` reports fals"
},
{
"path": "AI/SM90_BLOCK_SIZE_TUNING.md",
"chars": 6408,
"preview": "# SM90 Block Size Tuning Guide\n\nHow to choose tile sizes and MMA configurations for FlashAttention on Hopper (SM90).\n\n##"
},
{
"path": "AI/SM90_R2P_MASKING_SASS.md",
"chars": 4397,
"preview": "# SM90 FWD R2P Masking — SASS Investigation\n\n## SASS Instruction Counts (hdim=128, seqlen=113, tile_n=128)\n\nWith tile_n="
},
{
"path": "AI/VARLEN_PREPROCESS_TILE_BUG.md",
"chars": 1702,
"preview": "# Varlen Preprocess Tile Mismatch Bug\n\n## Summary\n\n`SeqlenInfo.create` in `flash_bwd_preprocess.py` defaulted `tile=128`"
},
{
"path": "AI/racecheck_repro_1d_bulk.py",
"chars": 3133,
"preview": "\"\"\"Minimal reproducer: cp.async.bulk (raw address) triggers racecheck hazard.\n\nWarp 0 loads via cp.async.bulk, warp 1 re"
},
{
"path": "AI/racecheck_repro_1d_tensor.py",
"chars": 3508,
"preview": "\"\"\"Minimal reproducer: cp.async.bulk.tensor.1d (descriptor TMA) passes racecheck.\n\nSame pipeline as racecheck_repro_1d_b"
},
{
"path": "AUTHORS",
"chars": 29,
"preview": "Tri Dao, trid@cs.stanford.edu"
},
{
"path": "CLAUDE.md",
"chars": 6624,
"preview": "# CLAUDE.md\n\nThis file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.\n\n## "
},
{
"path": "LICENSE",
"chars": 1558,
"preview": "BSD 3-Clause License\n\nCopyright (c) 2022, the respective contributors, as shown by the AUTHORS file.\nAll rights reserved"
},
{
"path": "MANIFEST.in",
"chars": 343,
"preview": "recursive-include csrc *.cu\nrecursive-include csrc *.h\nrecursive-include csrc *.cuh\nrecursive-include csrc *.cpp\nrecursi"
},
{
"path": "Makefile",
"chars": 126,
"preview": "\nclean_dist:\n\trm -rf dist/*\n\ncreate_dist: clean_dist\n\tpython setup.py sdist\n\nupload_package: create_dist\n\ttwine upload d"
},
{
"path": "README.md",
"chars": 24119,
"preview": "# FlashAttention\nThis repository provides the official implementation of FlashAttention and\nFlashAttention-2 from the\nfo"
},
{
"path": "benchmarks/bench_sm90.py",
"chars": 19877,
"preview": "#!/usr/bin/env python\n\"\"\"Unified SM90 benchmark for forward and backward passes.\n\nUsage:\n # Default: bench fwd+bwd fo"
},
{
"path": "benchmarks/benchmark_alibi.py",
"chars": 11039,
"preview": "# Copyright (c) 2024, Sanghun Cho, Tri Dao.\n\nimport pickle\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.n"
},
{
"path": "benchmarks/benchmark_attn.py",
"chars": 18225,
"preview": "import argparse\nimport time\nimport torch\n\ntry:\n import cudnn\nexcept ImportError:\n cudnn = None\n\nfrom einops import"
},
{
"path": "benchmarks/benchmark_causal.py",
"chars": 8882,
"preview": "from functools import partial\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einop"
},
{
"path": "benchmarks/benchmark_flash_attention.py",
"chars": 8118,
"preview": "# Install the newest triton version with\n# pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory"
},
{
"path": "benchmarks/benchmark_gemm.py",
"chars": 1752,
"preview": "import time\nimport torch\nimport torch.utils.benchmark as benchmark\n\nfrom triton.testing import do_bench\n\nif torch.versio"
},
{
"path": "csrc/flash_attn/flash_api.cpp",
"chars": 70407,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/alibi.h",
"chars": 2874,
"preview": "#include <cmath>\n\n#include \"namespace_config.h\"\n#include <cute/tensor.hpp>\n\n#include <cutlass/cutlass.h>\n#include <cutla"
},
{
"path": "csrc/flash_attn/src/block_info.h",
"chars": 2476,
"preview": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/dropout.h",
"chars": 5843,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/flash.h",
"chars": 6085,
"preview": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
"chars": 489,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
"chars": 481,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
"chars": 489,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
"chars": 481,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
"chars": 489,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"chars": 481,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
"chars": 485,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
"chars": 477,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
"chars": 485,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
"chars": 477,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
"chars": 485,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
"chars": 477,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_bwd_kernel.h",
"chars": 49186,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/flash_attn/src/flash_bwd_launch_template.h",
"chars": 16796,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/flash_bwd_preprocess_kernel.h",
"chars": 21098,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
"chars": 489,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
"chars": 481,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
"chars": 489,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
"chars": 481,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
"chars": 489,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
"chars": 481,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
"chars": 485,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
"chars": 477,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
"chars": 485,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
"chars": 477,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
"chars": 485,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
"chars": 487,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
"chars": 477,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
"chars": 479,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_kernel.h",
"chars": 76725,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/flash_fwd_launch_template.h",
"chars": 19961,
"preview": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
"chars": 430,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
"chars": 431,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
"chars": 426,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
"chars": 427,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
"chars": 430,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
"chars": 431,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
"chars": 426,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
"chars": 427,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
"chars": 430,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
"chars": 431,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
"chars": 426,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
"chars": 427,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
"chars": 429,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
"chars": 430,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
"chars": 425,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
"chars": 426,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
"chars": 429,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
"chars": 430,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
"chars": 425,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
"chars": 426,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
"chars": 429,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
"chars": 430,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
"chars": 425,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
"chars": 426,
"preview": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n/"
},
{
"path": "csrc/flash_attn/src/generate_kernels.py",
"chars": 3461,
"preview": "import argparse\nimport itertools\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List, Opt"
},
{
"path": "csrc/flash_attn/src/hardware_info.h",
"chars": 1634,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/kernel_traits.h",
"chars": 17697,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/mask.h",
"chars": 11284,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/namespace_config.h",
"chars": 1764,
"preview": "/**\n * @file flash_namespace_config.h\n * @brief Configuration file for Flash namespace management and isolation\n *\n * Th"
},
{
"path": "csrc/flash_attn/src/philox.cuh",
"chars": 1730,
"preview": "// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f"
},
{
"path": "csrc/flash_attn/src/philox_unpack.cuh",
"chars": 163,
"preview": "// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh\n\n#pragma once"
},
{
"path": "csrc/flash_attn/src/rotary.h",
"chars": 9002,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/softmax.h",
"chars": 9464,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn/src/static_switch.h",
"chars": 3795,
"preview": "// Inspired by\n// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n// and https://github.com/p"
},
{
"path": "csrc/flash_attn/src/utils.h",
"chars": 18486,
"preview": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/flash_api.cpp",
"chars": 8254,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/flash_common.cpp",
"chars": 1081,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/flash_common.hpp",
"chars": 3329,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/mha_bwd.cpp",
"chars": 18366,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/mha_fwd.cpp",
"chars": 15402,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
"chars": 26148,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
"chars": 19763,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/flash_attn_ck/mha_varlen_fwd.cpp",
"chars": 25075,
"preview": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n *******"
},
{
"path": "csrc/fused_dense_lib/README.md",
"chars": 491,
"preview": "This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu\n(forward and b"
},
{
"path": "csrc/fused_dense_lib/fused_dense.cpp",
"chars": 9969,
"preview": "// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp\n// We make it work for bfloat16\n#include"
},
{
"path": "csrc/fused_dense_lib/fused_dense_cuda.cu",
"chars": 24690,
"preview": "// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu\n#include <ATen/ATen.h>\n#include <ATe"
},
{
"path": "csrc/fused_dense_lib/setup.py",
"chars": 1262,
"preview": "import os\nimport subprocess\nfrom packaging.version import parse, Version\n\nimport torch\nfrom setuptools import setup\nfrom"
},
{
"path": "csrc/layer_norm/README.md",
"chars": 865,
"preview": "This CUDA extension implements fused dropout + residual + LayerNorm, building on\nApex's [FastLayerNorm](https://github.c"
},
{
"path": "csrc/layer_norm/ln.h",
"chars": 7248,
"preview": "#pragma once\n\n#include <unordered_map>\n#include <cuda_fp16.h>\n#include <cuda_bf16.h>\n\n#ifdef OLD_GENERATOR_PATH\n#include"
},
{
"path": "csrc/layer_norm/ln_api.cpp",
"chars": 36132,
"preview": "#include <torch/extension.h>\n#include \"ATen/cuda/CUDAContext.h\"\n#include <c10/cuda/CUDAGuard.h>\n\n#include \"ln.h\"\n\n/*\n\nSu"
},
{
"path": "csrc/layer_norm/ln_bwd_1024.cu",
"chars": 987,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_1280.cu",
"chars": 987,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_1536.cu",
"chars": 977,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_2048.cu",
"chars": 976,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_256.cu",
"chars": 977,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_2560.cu",
"chars": 977,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_3072.cu",
"chars": 976,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_4096.cu",
"chars": 976,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_512.cu",
"chars": 977,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_5120.cu",
"chars": 976,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_6144.cu",
"chars": 976,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_7168.cu",
"chars": 976,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_768.cu",
"chars": 977,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_8192.cu",
"chars": 976,
"preview": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE,"
},
{
"path": "csrc/layer_norm/ln_bwd_kernels.cuh",
"chars": 25647,
"preview": "#pragma once\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"static_switch.h\"\n\nnamespac"
},
{
"path": "csrc/layer_norm/ln_fwd_1024.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_1280.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_1536.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_2048.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_256.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_2560.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_3072.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_4096.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_512.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_5120.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_6144.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_7168.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_768.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_8192.cu",
"chars": 925,
"preview": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HIDDEN_SIZE, WTYPE, "
},
{
"path": "csrc/layer_norm/ln_fwd_kernels.cuh",
"chars": 12721,
"preview": "#pragma once\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl."
},
{
"path": "csrc/layer_norm/ln_kernel_traits.h",
"chars": 6655,
"preview": "#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nname"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_1024.cu",
"chars": 1095,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_1280.cu",
"chars": 1095,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_1536.cu",
"chars": 1085,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_2048.cu",
"chars": 1084,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_256.cu",
"chars": 1085,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_2560.cu",
"chars": 1085,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_3072.cu",
"chars": 1084,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_4096.cu",
"chars": 1145,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_512.cu",
"chars": 1085,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_5120.cu",
"chars": 1145,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_6144.cu",
"chars": 1084,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_7168.cu",
"chars": 1084,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_768.cu",
"chars": 1085,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_bwd_8192.cu",
"chars": 1084,
"preview": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n// H"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_1024.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_1280.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_1536.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_2048.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_256.cu",
"chars": 1032,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_2560.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_3072.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_4096.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_512.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_5120.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_6144.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_7168.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_768.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_fwd_8192.cu",
"chars": 1033,
"preview": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n// HI"
},
{
"path": "csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh",
"chars": 24916,
"preview": "#pragma once\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"static_switch.h\"\n#include "
},
{
"path": "csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh",
"chars": 12530,
"preview": "#pragma once\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl."
},
{
"path": "csrc/layer_norm/ln_utils.cuh",
"chars": 29989,
"preview": "#pragma once\n\n#include <cassert>\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#include \"ln.h\"\n\n//////////////////////"
},
{
"path": "csrc/layer_norm/setup.py",
"chars": 8032,
"preview": "# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py\nimport sys\nimport warnings\nimport os\nfrom packaging.v"
},
{
"path": "csrc/layer_norm/static_switch.h",
"chars": 1278,
"preview": "// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n// and https://github.com/pyto"
}
]
// ... and 794 more files (download for full content)
About this extraction
This page contains the full source code of the Dao-AILab/flash-attention GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 994 files (5.1 MB), approximately 1.4M tokens, and a symbol index with 2142 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.